from scipy.interpolate import griddata
from PIL import Image, ImageFile, ImageDraw
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.ops import sigmoid_focal_loss
from torchvision.transforms import ToTensor, Resize, Normalize, ColorJitter, RandomAdjustSharpness, \
    RandomRotation, RandomAutocontrast, ToPILImage
from torch import nn
from torchvision.models import resnet101, resnet18, resnet34, resnet50
from torchvision.models.feature_extraction import create_feature_extractor
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm
from torch.optim import Adam
from scipy.ndimage import gaussian_filter

import csv
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
import os
import torch.nn.functional as F
import numpy as np
import sys
import wandb
import random
import multiprocessing as mp
import ctypes


# Experiment-wide settings read by the dataset, model and training code.
config = dict(
    input_shape=(256, 256),        # resize target; also scales normalized boxes to pixels
    epochs=30,                     # training hyper-parameter (consumed outside this chunk)
    batch_size=600,                # training hyper-parameter (consumed outside this chunk)
    num_workers=64,                # upper bound on training DataLoader workers when on a server
    stride=(4, 8, 16, 32),         # feature stride of ResNet stages layer1..layer4
    radius=(None, 12, 12, None),   # positive-anchor radius in pixels per stage; None disables the stage
    lr=1e-3,                       # training hyper-parameter (consumed outside this chunk)
    loss_cof=(1., 3.),             # presumably loss weights — used outside this chunk, confirm
    img_tower_depth=4,             # conv stages per image tower in HeadGeneral
    sticker_tower_depth=3,         # conv stages per sticker tower in HeadGeneral
    use_global_cls=True,           # add a pooled global image feature to the classification head
    use_global_reg=True,           # add a pooled global image feature to the regression head
    desc='Reduce center radius',   # free-text run description
    source='samples_20.csv',       # CSV listing the training samples
    backbone='resnet50',           # key into layer_info / backbone_gen
    focal_alpha=0.25,              # presumably the focal-loss alpha — used outside this chunk
    cls_hidden_dim=1024,           # hidden width of the classification MLP
    reg_hidden_dim=1024,           # hidden width of the regression MLP
    model_dir='sol_model_12rad',   # output directory for cached datasets/loaders
    lr_milestones=[5, 10, 15],     # presumably MultiStepLR milestones — used outside this chunk
    training_data_sample=None,     # not referenced in this chunk
    normalize_output=True,         # post-process the predicted box in eval_single
)


# Kept for local runs: uncomment to disable Weights & Biases logging.
# wandb.init(mode="disabled")

# Raise CPython's recursion limit (default 1000).
# NOTE(review): the reason is not visible in this chunk (possibly deep pickling
# or graph tracing) — confirm before changing.
sys.setrecursionlimit(5000)
# Share tensors between DataLoader worker processes via the file system rather
# than via file descriptors.
torch.multiprocessing.set_sharing_strategy('file_system')
# Let PIL decode truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def d_iou(pr_bboxes, gt_bboxes, iou_only=False, eps=1e-7):
    """Compute IoU and Distance-IoU (DIoU) for paired boxes.

    Args:
        pr_bboxes: predicted boxes, one per row, laid out as ``(x1, y1, x2, y2)``.
        gt_bboxes: ground-truth boxes in the same layout; row ``i`` is compared
            with row ``i`` of ``pr_bboxes``.
        iou_only: when True, return only the IoU tensor.
        eps: added to the union area and to the squared diagonal of the
            enclosing box so that degenerate boxes (zero area, or two identical
            points) give finite values instead of NaN.

    Returns:
        ``iou`` if ``iou_only`` else ``(diou, 1 - diou, iou)`` where
        ``diou = iou - center_dist**2 / enclosing_diag**2`` clamped to [-1, 1].
    """
    gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
    gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
    pr_w = pr_bboxes[:, 2] - pr_bboxes[:, 0]
    pr_h = pr_bboxes[:, 3] - pr_bboxes[:, 1]

    gt_area = gt_w * gt_h
    pr_area = pr_w * pr_h

    gt_center_x = (gt_bboxes[:, 2] + gt_bboxes[:, 0]) / 2
    gt_center_y = (gt_bboxes[:, 3] + gt_bboxes[:, 1]) / 2
    pr_center_x = (pr_bboxes[:, 2] + pr_bboxes[:, 0]) / 2
    pr_center_y = (pr_bboxes[:, 3] + pr_bboxes[:, 1]) / 2
    dis_center_sqr = (gt_center_x - pr_center_x) ** 2 + (gt_center_y - pr_center_y) ** 2

    # Intersection rectangle (empty intersections clamp to zero width/height).
    lt = torch.max(gt_bboxes[:, :2], pr_bboxes[:, :2])
    rb = torch.min(gt_bboxes[:, 2:], pr_bboxes[:, 2:])
    wh = torch.clamp(rb - lt, min=0)
    inter = wh[:, 0] * wh[:, 1]
    union = gt_area + pr_area - inter
    iou = inter / (union + eps)
    if iou_only:
        return iou

    # Smallest box enclosing both boxes; its squared diagonal normalizes the
    # center distance.
    lt = torch.min(gt_bboxes[:, :2], pr_bboxes[:, :2])
    rb = torch.max(gt_bboxes[:, 2:], pr_bboxes[:, 2:])
    out_diag = torch.clamp(rb[:, 0] - lt[:, 0], min=0) ** 2 + torch.clamp(rb[:, 1] - lt[:, 1], min=0) ** 2

    dious = iou - dis_center_sqr / (out_diag + eps)
    dious = torch.clamp(dious, min=-1.0, max=1.0)
    return dious, 1. - dious, iou


def read_image(path):
    """Split an image into an RGB image and a single-channel alpha image.

    Args:
        path: a file path (``str``) or an already opened PIL image.

    Returns:
        ``(rgb, alpha)`` PIL images. For images carrying transparency ('RGBA',
        'LA', or palette images with a transparency key) ``alpha`` is that
        channel; otherwise ``alpha`` is a fully opaque (255) 'L' image of size
        ``config['input_shape']``.
    """
    img = Image.open(path) if isinstance(path, str) else path

    # convert('RGB') would silently drop the transparency of these modes, so
    # normalize them to RGBA first and take the RGBA path below.
    if img.mode == 'LA' or (img.mode == 'P' and 'transparency' in img.info):
        img = img.convert('RGBA')

    if img.mode == 'RGBA':
        return img.convert('RGB'), img.getchannel('A')
    return img.convert('RGB'), Image.fromarray(np.full(config['input_shape'], 255, dtype=np.uint8))


class AddStickerDataset(Dataset):
    """Dataset of (background image, sticker) pairs with the sticker's target box.

    Each CSV row provides ``input`` (background path), ``sticker`` (sticker
    path), ``output`` (copied verbatim when exporting) and the sticker box
    ``x, y, w, h`` as fractions of the background size. While
    ``self.training`` is True the background and the sticker may each be
    cropped on one random side, and the box is adjusted to match.
    """

    @staticmethod
    def _crop_fraction(image, box):
        """Crop a PIL image to ``box = (left, top, right, bottom)`` given as fractions of its size."""
        width, height = image.size
        left, top, right, bottom = box
        return image.crop((left * width, top * height, right * width, bottom * height))

    @staticmethod
    def _build_label(x, y, w, h, input_shape, radius, strides):
        """Encode a pixel-space box ``(x, y, w, h)`` as per-anchor targets.

        There is one anchor per feature-map cell for every level whose radius is
        not None. Anchors are ordered by level, then by column (x), then by row (y).
        An anchor is positive when its point lies inside the box and strictly
        within ``rad`` pixels of the box center. Its row is then
        ``(1, left, top, right, bottom)``, with distances in units of the
        level's stride. Every other row is zero.

        Returns:
            float64 array of shape ``(num_anchors, 5)``; ``(0, 5)`` when no level
            is enabled.
        """
        center_x = x + w * 0.5
        center_y = y + h * 0.5
        levels = []
        for rad, stride in zip(radius, strides):
            if rad is None:
                continue
            half = int(stride / 2)
            cols = half + np.arange(input_shape[0] // stride) * stride
            rows = half + np.arange(input_shape[1] // stride) * stride
            # indexing='ij' + ravel keeps "column outer, row inner" ordering.
            mapped_x, mapped_y = (grid.ravel() for grid in np.meshgrid(cols, rows, indexing='ij'))
            positive = ((x <= mapped_x) & (mapped_x <= x + w) &
                        (y <= mapped_y) & (mapped_y <= y + h) &
                        ((mapped_x - center_x) ** 2 + (mapped_y - center_y) ** 2 < rad ** 2))
            level = np.zeros((mapped_x.size, 5))
            level[positive, 0] = 1.
            level[positive, 1] = (mapped_x[positive] - x) / stride
            level[positive, 2] = (mapped_y[positive] - y) / stride
            level[positive, 3] = (x + w - mapped_x[positive]) / stride
            level[positive, 4] = (y + h - mapped_y[positive]) / stride
            levels.append(level)
        return np.concatenate(levels) if levels else np.zeros((0, 5))

    @staticmethod
    def _keep_row(row, input_shape):
        """Return True when a CSV row passes the sample filters applied in ``__init__``."""
        if row['brush_mask']:  # rows with a brush mask are skipped
            return False
        # A missing or empty rotation is treated as "not rotated"
        # (float('') would raise ValueError).
        if row['rotation'] and abs(float(row['rotation'])) > 10:
            return False
        if float(row['w']) * input_shape[0] < 9 or float(row['h']) * input_shape[1] < 9:
            return False
        if float(row['opacity']) < 50:
            return False
        center_x = float(row['x']) + float(row['w']) * 0.5
        center_y = float(row['y']) + float(row['h']) * 0.5
        return 1. / 16 <= center_x <= 15. / 16 and 1. / 16 <= center_y <= 15. / 16

    def process_row(self, row):
        """Load, augment and encode one CSV row.

        Returns:
            ``(img_tensor, sticker_tensor, label, gt_box, input_path,
            sticker_path, output_path)``. ``img_tensor`` stacks RGB + alpha +
            mask channels and ``sticker_tensor`` stacks RGB + alpha; both have
            their last two axes swapped. ``label`` comes from ``_build_label``.
            ``gt_box`` is ``(x1, y1, x2, y2, w, h)`` in normalized coordinates
            after cropping.
        """
        x, y, w, h = (float(row[name]) for name in ['x', 'y', 'w', 'h'])

        raw_img, img_alpha = read_image(row['input'])
        raw_sticker, sticker_alpha = read_image(row['sticker'])

        # The mask path is derived from the input path: the last 5 characters of
        # the stem are replaced by 'mask' and the extension by 'jpg'.
        slices = row['input'].split('.')
        slices[-2] = slices[-2][:-5] + 'mask'
        slices[-1] = 'jpg'
        mask_path = '.'.join(slices)
        mask = Image.open(mask_path)

        # 1) Optionally crop one side of the background, never cutting into the sticker.
        crop_side = random.randint(0, 5) if self.training else 4
        img_box = None
        if crop_side == 0 and x > 0:  # left
            offset = random.uniform(0.001, min(0.3, x))
            img_box = (offset, 0., 1., 1.)
            x = (x - offset) / (1. - offset)
            w = w / (1. - offset)
        elif crop_side == 1 and y > 0:  # top
            offset = random.uniform(0.001, min(0.3, y))
            img_box = (0., offset, 1., 1.)
            y = (y - offset) / (1. - offset)
            h = h / (1. - offset)
        elif crop_side == 2 and x + w < 1:  # right
            offset = random.uniform(0.001, min(0.3, 1 - x - w))
            img_box = (0., 0., 1. - offset, 1.)
            x = x / (1. - offset)
            w = w / (1. - offset)
        elif crop_side == 3 and y + h < 1:  # bottom
            offset = random.uniform(0.001, min(0.3, 1 - y - h))
            img_box = (0., 0., 1., 1. - offset)
            y = y / (1. - offset)
            h = h / (1. - offset)

        if img_box is None:
            img = raw_img.copy()
        else:
            img = self._crop_fraction(raw_img, img_box)
            mask = self._crop_fraction(mask, img_box)
            img_alpha = self._crop_fraction(img_alpha, img_box)

        # 2) Optionally crop one side of the sticker; the target box shrinks with it.
        crop_side = random.randint(0, 5) if self.training else 4
        sticker_box = None
        if crop_side == 0:  # left
            offset = random.uniform(0.001, 0.2)
            sticker_box = (offset, 0., 1., 1.)
            x += w * offset
            w -= w * offset
        elif crop_side == 1:  # top
            offset = random.uniform(0.001, 0.2)
            sticker_box = (0., offset, 1., 1.)
            y += h * offset
            h -= h * offset
        elif crop_side == 2:  # right
            offset = random.uniform(0.001, 0.2)
            sticker_box = (0., 0., 1. - offset, 1.)
            w -= w * offset
        elif crop_side == 3:  # bottom
            offset = random.uniform(0.001, 0.2)
            sticker_box = (0., 0., 1., 1. - offset)
            h -= h * offset

        if sticker_box is None:
            sticker = raw_sticker.copy()
        else:
            sticker = self._crop_fraction(raw_sticker, sticker_box)
            sticker_alpha = self._crop_fraction(sticker_alpha, sticker_box)

        to_export = None
        if self.export is not None:
            # ``-=`` on a multiprocessing.Value is a separate read and write, so
            # hold its lock; otherwise concurrent DataLoader workers can export
            # under the same index or over-run the budget.
            with self.export_count.get_lock():
                if self.export_count.value > 0:
                    self.export_count.value -= 1
                    to_export = self.export_count.value

        gt_box = (x, y, x + w, y + h, w, h)

        label = self._build_label(x * self.input_shape[0], y * self.input_shape[1],
                                  w * self.input_shape[0], h * self.input_shape[1],
                                  self.input_shape, self.radius, self.strides)
        if np.sum(label[:, 0]) == 0:
            print(row)  # no positive anchor for this sample

        # NOTE(review): these transforms include random colour/sharpness
        # augmentation and run even when ``self.training`` is False.
        img_origin = self.preprocess_image(img)
        sticker_origin = self.preprocess_sticker(sticker)
        if to_export is not None:
            img_sample = self.inv_image(img_origin)
            img_sample.save(os.path.join(self.export, '{}_input.{}'.format(to_export, row['input'].split('.')[-1])))
            sticker_sample = self.inv_image(sticker_origin)
            sticker_sample.save(os.path.join(self.export, '{}_sticker.{}'.format(to_export, row['sticker'].split('.')[-1])))

            shutil.copy(row['output'],
                        os.path.join(self.export, '{}_output.{}'.format(to_export, row['output'].split('.')[-1])))
            shutil.copy(mask_path, os.path.join(self.export, '{}_mask.jpg'.format(to_export)))

        # Swap the last two axes so the first spatial axis is x, matching the
        # column-major anchor order of ``_build_label``.
        img_origin = img_origin.transpose(-1, -2)
        img_alpha = self.preprocess_alpha(img_alpha).transpose(-1, -2)
        img_mask = self.preprocess_mask(mask).transpose(-1, -2)
        img_tensor = torch.concat((img_origin, img_alpha, img_mask), dim=0)

        sticker_origin = sticker_origin.transpose(-1, -2)
        sticker_alpha = self.preprocess_alpha(sticker_alpha).transpose(-1, -2)
        sticker_tensor = torch.concat((sticker_origin, sticker_alpha), dim=0)

        return img_tensor, sticker_tensor, label, gt_box, row['input'], row['sticker'], row['output']

    def __init__(self, source, config, export):
        """
        Args:
            source: path of the CSV file listing samples.
            config: dict with 'input_shape', 'radius' and 'strides'.
            export: directory where up to 100 processed samples are exported,
                or None to disable exporting.
        """
        self._data = list()
        self.box_info = list()
        self.training = True

        self.input_shape = config['input_shape']
        self.radius = config['radius']
        self.strides = config['strides']
        self.export = export
        if self.export is not None:
            # Shared export budget across DataLoader worker processes.
            self.export_count = mp.Value(ctypes.c_int, 100)
        else:
            self.export_count = 0

        with open(source, 'r', newline='') as source_file:
            for row in csv.DictReader(source_file):
                if self._keep_row(row, config['input_shape']):
                    self._data.append(row)

        self.preprocess_image = torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            ColorJitter(0.4, 0.4, 0.4),
            RandomAdjustSharpness(sharpness_factor=2),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Undo the ImageNet normalization above and convert back to a PIL image.
        self.inv_image = torchvision.transforms.Compose([
            Normalize(mean=[0., 0., 0.], std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
            Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
            ToPILImage(),
        ])

        self.preprocess_sticker = torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            ColorJitter(0.4, 0.4, 0.4),
            RandomAdjustSharpness(sharpness_factor=2),
            RandomRotation(10),
            RandomAutocontrast(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.preprocess_alpha = torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            Normalize(mean=[0.927], std=[0.256])
        ])
        self.preprocess_mask = torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            Normalize(mean=[0.385], std=[0.484])
        ])

        # Normalized (anchor_x, anchor_y, stride) per anchor, in the same order
        # as the rows produced by ``_build_label``.
        for rad, stride in zip(self.radius, self.strides):
            if rad is not None:
                feature_x, feature_y = self.input_shape[0] // stride, self.input_shape[1] // stride

                for c_x in range(feature_x):
                    for c_y in range(feature_y):
                        mapped_x = int(stride / 2) + c_x * stride
                        mapped_y = int(stride / 2) + c_y * stride
                        self.box_info.append((mapped_x / self.input_shape[0],
                                              mapped_y / self.input_shape[1],
                                              stride / self.input_shape[0]))

        print('{} samples collected.'.format(len(self._data)))

    def __len__(self):
        return len(self._data)

    def __getitem__(self, item):
        return self.process_row(self._data[item])


def prepare_dataset(source, from_scratch, on_server, model_dir, batch_size=4, export=None):
    """Create train/val/test DataLoaders for the sticker dataset.

    Args:
        source: CSV file read by ``AddStickerDataset`` (used only when ``from_scratch``).
        from_scratch: build the dataset from ``source`` (and cache it as
            ``<model_dir>/dataset_sg.pt`` unless exporting) instead of loading that cache.
        on_server: allow up to ``config['num_workers']`` training workers instead of 2.
        model_dir: directory for the cached dataset, the pickled loaders and ``box_info``.
        batch_size: batch size of all three loaders.
        export: optional directory (recreated empty) for exported samples; when
            set, no dataset/loader files are written to ``model_dir``.

    Returns:
        ``(train_loader, val_loader, test_loader, box_info)``. The split is about
        98% / 1% / 1%. The val/test samples are produced with ``training=False``
        (no random cropping).
    """
    import copy
    from torch.utils.data import Subset

    if export is not None:
        if os.path.exists(export):
            shutil.rmtree(export)
        os.mkdir(export)
    os.makedirs(model_dir, exist_ok=True)

    if from_scratch:
        dataset_config = {
            'input_shape': config['input_shape'],
            'radius': config['radius'],
            'strides': config['stride'],
        }
        dataset = AddStickerDataset(source, dataset_config, export=export)
        if export is None:
            torch.save(dataset, os.path.join(model_dir, 'dataset_sg.pt'))
    else:
        # NOTE: torch.load unpickles arbitrary objects; only load caches you created.
        dataset = torch.load(os.path.join(model_dir, 'dataset_sg.pt'))

    train_len = int(len(dataset) * 0.98)
    val_len = int((len(dataset) - train_len) * 0.5)
    test_len = len(dataset) - train_len - val_len
    train, val, test = random_split(dataset, [train_len, val_len, test_len])

    # random_split returns Subset views of ONE shared dataset, so assigning
    # ``training`` on a Subset never reached the dataset. Serve val/test from a
    # shallow copy whose ``training`` flag is off; the train split keeps the original.
    eval_dataset = copy.copy(dataset)
    eval_dataset.training = False
    val = Subset(eval_dataset, val.indices)
    test = Subset(eval_dataset, test.indices)
    box_info = dataset.box_info

    num_workers = min(batch_size, config['num_workers']) if on_server else 2

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=True, num_workers=8)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=True, num_workers=8)

    if export is None:
        torch.save(train_loader, os.path.join(model_dir, 'sol_train.pt'))
        torch.save(val_loader, os.path.join(model_dir, 'sol_val.pt'))
        torch.save(test_loader, os.path.join(model_dir, 'sol_test.pt'))
        torch.save(box_info, os.path.join(model_dir, 'box_info.pt'))
        print('Dataset dumped.')
    else:
        print('Dataset not dumped as for exporting.')

    return train_loader, val_loader, test_loader, box_info


def init_conv_std(module, std=0.1):
    """Re-initialize a ``nn.Conv2d``: weights ~ N(0, std**2), bias filled with 0.01.

    Meant to be passed to ``nn.Module.apply``. Every other module type
    (including ``nn.Conv1d``) is left untouched.
    """
    if not isinstance(module, nn.Conv2d):
        return
    nn.init.normal_(module.weight, std=std)
    if module.bias is None:
        return
    nn.init.constant_(module.bias, 0.01)


def normalized_euclidean_distance(u, v, dim=-1):
    """Return the normalized squared Euclidean distance between ``u`` and ``v``.

    Computed along ``dim`` as ``0.5 * var(u - v) / (var(u) + var(v))`` with
    ``torch.var``'s default (unbiased) estimator.
    """
    denominator = torch.var(u, dim=dim) + torch.var(v, dim=dim)
    return 0.5 * torch.var(u - v, dim=dim) / denominator


class HeadGeneral(nn.Module):
    """Prediction head for one ResNet stage.

    At every spatial location of the image feature map it concatenates the
    image feature, the pooled sticker feature and, optionally, a pooled global
    image feature. From these it predicts one classification logit and four
    box distances.
    """

    def __init__(self, img_channel, sticker_channel, img_depth, sticker_depth):
        """
        Args:
            img_channel: channel count of the image feature map (GroupNorm(32, ...)
                requires it to be divisible by 32).
            sticker_channel: ``(in_channels, out_channels)`` of the sticker towers.
            img_depth: number of conv3x3 -> GroupNorm -> ReLU stages per image tower.
            sticker_depth: number of conv1x1 -> GroupNorm -> ReLU stages per sticker tower.

        Whether the global-image branches exist is decided here from
        ``config['use_global_cls']`` and ``config['use_global_reg']``.
        """
        super().__init__()
        sticker_in, sticker_out = sticker_channel

        # Submodules are created and registered in the original order, so seeded
        # initialization and state_dict keys are unchanged.
        img_cls_tower, img_reg_tower = [], []
        for _ in range(img_depth):
            img_cls_tower += self._norm_relu(
                nn.Conv2d(img_channel, img_channel, (3, 3), padding=1, bias=False), img_channel)
            img_reg_tower += self._norm_relu(
                nn.Conv2d(img_channel, img_channel, (3, 3), padding=1, bias=False), img_channel)

        sticker_cls_tower, sticker_reg_tower = [], []
        for i in range(sticker_depth):
            inc = sticker_in if i == 0 else sticker_out
            sticker_cls_tower += self._norm_relu(nn.Conv1d(inc, sticker_out, (1,)), sticker_out)
            sticker_reg_tower += self._norm_relu(nn.Conv1d(inc, sticker_out, (1,)), sticker_out)

        self.img_cls_tower = nn.Sequential(*img_cls_tower)
        self.img_reg_tower = nn.Sequential(*img_reg_tower)
        self.sticker_cls_tower = nn.Sequential(*sticker_cls_tower)
        self.sticker_reg_tower = nn.Sequential(*sticker_reg_tower)

        use_global_cls = config['use_global_cls']
        if use_global_cls:
            self.global_cls_head = nn.Sequential(*self._norm_relu(
                nn.Conv1d(img_channel, img_channel, (1,)), img_channel))
        cls_in = img_channel * (2 if use_global_cls else 1) + sticker_out
        self.cls_head = self._mlp(cls_in, config['cls_hidden_dim'], 1)

        use_global_reg = config['use_global_reg']
        if use_global_reg:
            self.global_reg_head = nn.Sequential(*self._norm_relu(
                nn.Conv1d(img_channel, img_channel, (1,)), img_channel))
        reg_in = img_channel * (2 if use_global_reg else 1) + sticker_out
        self.reg_head = self._mlp(reg_in, config['reg_hidden_dim'], 4)

        self.apply(init_conv_std)

    @staticmethod
    def _norm_relu(conv, channels):
        """Return ``[conv, GroupNorm(32, channels), ReLU]`` as a list of modules."""
        return [conv, nn.GroupNorm(32, channels), nn.ReLU()]

    @staticmethod
    def _mlp(in_channels, hidden, out_channels):
        """Three pointwise Conv1d layers with ReLUs in between."""
        return nn.Sequential(
            nn.Conv1d(in_channels, hidden, (1,)),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, (1,)),
            nn.ReLU(),
            nn.Conv1d(hidden, out_channels, (1,)),
        )

    def forward(self, img_feature_map, sticker_feature, img_global):
        """Predict per-location logits and box distances.

        Args:
            img_feature_map: image features; flattened here to ``(B, C, L)``
                with ``L`` the product of the two spatial sizes.
            sticker_feature: pooled sticker feature, assumed ``(B, C_in, 1)``
                (as produced by ``AddStickerGeneral.avgpool``); broadcast to all
                ``L`` locations.
            img_global: pooled image feature, broadcast likewise.

        Returns:
            ``(cls_pred, reg_pred)`` with shapes ``(B, L, 1)`` and ``(B, L, 4)``.
            ``reg_pred`` is exponentiated, so all distances are positive.
        """
        num_locations = img_feature_map.shape[-1] * img_feature_map.shape[-2]
        img_global = img_global.expand(-1, -1, num_locations)

        cls_img_feature = self.img_cls_tower(img_feature_map)
        cls_img_feature = cls_img_feature.reshape(cls_img_feature.shape[0], cls_img_feature.shape[1], -1)
        cls_parts = [cls_img_feature,
                     self.sticker_cls_tower(sticker_feature).expand(-1, -1, cls_img_feature.shape[-1])]
        # Branch on what __init__ actually built rather than re-reading the
        # global config, which may have changed since construction.
        if hasattr(self, 'global_cls_head'):
            cls_parts.append(self.global_cls_head(img_global))
        cls_pred = self.cls_head(torch.concat(cls_parts, dim=1)).transpose(-1, -2)

        reg_img_feature = self.img_reg_tower(img_feature_map) \
            .reshape((img_feature_map.shape[0], img_feature_map.shape[1], -1))
        reg_parts = [reg_img_feature,
                     self.sticker_reg_tower(sticker_feature).expand(-1, -1, reg_img_feature.shape[-1])]
        if hasattr(self, 'global_reg_head'):
            reg_parts.append(self.global_reg_head(img_global))
        reg_pred = torch.exp(self.reg_head(torch.concat(reg_parts, dim=1))).transpose(-1, -2)

        return cls_pred, reg_pred


# Output channels of ResNet stages layer1..layer4, per backbone name.
# Bottleneck ResNets (50/101) widen by 4x relative to basic-block ResNets (18/34).
layer_info = {
    **dict.fromkeys(('resnet101', 'resnet50'), (256, 512, 1024, 2048)),
    **dict.fromkeys(('resnet34', 'resnet18'), (64, 128, 256, 512)),
}

# torchvision constructor for each supported backbone name (keys match layer_info).
backbone_gen = dict(
    resnet101=resnet101,
    resnet50=resnet50,
    resnet34=resnet34,
    resnet18=resnet18,
)


class AddStickerGeneral(nn.Module):
    """Two-backbone sticker placement model.

    An image backbone (5 input channels: RGB, alpha, mask) and a sticker
    backbone (4 input channels: RGB, alpha) feed one ``HeadGeneral`` per ResNet
    stage whose ``config['radius']`` entry is not None. The predictions of all
    stages are concatenated along the location axis.
    """

    def __init__(self, backbone, box_info):
        """
        Args:
            backbone: key into ``layer_info`` / ``backbone_gen`` (e.g. 'resnet50').
                It sizes the heads AND selects the feature extractors, so the two
                always agree.
            box_info: per-location ``(anchor_x, anchor_y, stride)`` in normalized
                units, indexed as ``box_info[i][k]`` (presumably the list returned
                by ``prepare_dataset`` — confirm at the call site).
        """
        super().__init__()
        self.layers = []
        self.box_info = box_info

        channels = layer_info[backbone]
        heads = []
        for i, rad in enumerate(config['radius']):
            if rad is None:
                continue
            self.layers.append('layer{}'.format(i + 1))
            heads.append(HeadGeneral(channels[i], (channels[i], min(1024, channels[i])),
                                     config['img_tower_depth'], config['sticker_tower_depth']))
        self.heads = nn.ModuleList(heads)

        self.avgpool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(2),
        )

        # Previously the backbone was taken from config['backbone'] while the
        # heads used ``backbone``; a mismatch produced a channel-size error.
        self.img_backbone = create_feature_extractor(backbone_gen[backbone](pretrained=True),
                                                     return_nodes=self.layers)
        self.sticker_backbone = create_feature_extractor(backbone_gen[backbone](pretrained=True),
                                                         return_nodes=self.layers)

        # The stems take extra channels (alpha, mask), so conv1 is replaced.
        # NOTE(review): this discards the pretrained conv1 weights; consider
        # copying them into the first three input channels.
        self.img_backbone.conv1 = nn.Conv2d(5, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.sticker_backbone.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        self.preprocess_image = torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            ColorJitter(0.4, 0.4, 0.4),
            RandomAdjustSharpness(sharpness_factor=2),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.preprocess_alpha = torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            Normalize(mean=[0.927], std=[0.256])
        ])
        self.preprocess_mask = torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            Normalize(mean=[0.385], std=[0.484])
        ])

        self.log = dict()

    def forward(self, img, sticker):
        """Run both backbones and every head.

        Returns ``(cls_pred, reg_pred)``: the per-stage head outputs
        concatenated along dim 1 (the location axis).
        """
        img_feature = self.img_backbone(img)
        sticker_feature = self.sticker_backbone(sticker)

        cls_predictions = []
        reg_predictions = []

        for layer, head in zip(self.layers, self.heads):
            img_local = img_feature[layer]
            img_global = self.avgpool(img_local)
            sticker_global = self.avgpool(sticker_feature[layer])
            cls_pred, reg_pred = head(img_local, sticker_global, img_global)
            cls_predictions.append(cls_pred)
            reg_predictions.append(reg_pred)

        cls_pred = torch.concat(cls_predictions, dim=1)
        reg_pred = torch.concat(reg_predictions, dim=1)
        return cls_pred, reg_pred

    @staticmethod
    def normalize_wh_ratio(img, sticker, pred_w, pred_h, strength):
        """Pull a predicted box's aspect ratio towards the sticker's own aspect ratio.

        ``pred_w`` / ``pred_h`` are fractions of the background size, so ratios
        are compared in that normalized space. A sticker of pixel size
        ``(sw, sh)`` on a background of ``(W, H)`` has normalized ratio
        ``(sw / sh) * (H / W)``. The predicted area is preserved.

        For ``strength < 3`` the new ratio is the weighted geometric mean
        ``(sticker_ratio ** strength * pred_ratio) ** (1 / (strength + 1))``;
        otherwise the sticker ratio is used as is. Boxes whose area is not below
        0.5 are returned unchanged.
        """
        pred_area = pred_w * pred_h
        if not pred_area < 0.5:
            return pred_w, pred_h

        sticker_w, sticker_h = sticker.size
        sticker_ratio = sticker_w / sticker_h * img.size[1] / img.size[0]
        if strength < 3:
            # pred_w / pred_h is already a normalized ratio. The previous code
            # multiplied it by H / W a second time, skewing non-square backgrounds.
            pred_ratio = pred_w / pred_h
            target_ratio = ((sticker_ratio ** strength) * pred_ratio) ** (1. / (strength + 1))
        else:
            target_ratio = sticker_ratio
        normal_h = (pred_area / target_ratio) ** 0.5
        normal_w = normal_h * target_ratio
        return normal_w, normal_h

    @staticmethod
    def _eval_rgb_transform():
        """Deterministic RGB preprocessing for inference (no colour augmentation)."""
        return torchvision.transforms.Compose([
            ToTensor(),
            Resize(config['input_shape']),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _decode_box(self, reg_pred, idx):
        """Decode location ``idx`` of ``reg_pred`` into a normalized ``(x, y, w, h)`` box.

        ``reg_pred[0, idx]`` holds (left, top, right, bottom) distances in units
        of the location's stride. ``self.box_info[idx]`` holds
        (anchor_x, anchor_y, stride).
        """
        anchor_x, anchor_y, stride = self.box_info[idx][0], self.box_info[idx][1], self.box_info[idx][2]
        x = anchor_x - reg_pred[0, idx, 0] * stride
        y = anchor_y - reg_pred[0, idx, 1] * stride
        w = (reg_pred[0, idx, 0] + reg_pred[0, idx, 2]) * stride
        h = (reg_pred[0, idx, 1] + reg_pred[0, idx, 3]) * stride
        return x, y, w, h

    def get_sticker_feature(self, sticker, device, layer='layer3'):
        """Return the pooled sticker-backbone feature of ``layer`` for one sticker."""
        raw_sticker, sticker_alpha = read_image(sticker)
        sticker_origin = self._eval_rgb_transform()(raw_sticker).transpose(-1, -2)
        sticker_alpha = self.preprocess_alpha(sticker_alpha).transpose(-1, -2)
        sticker_tensor = torch.concat((sticker_origin, sticker_alpha), dim=0).to(device).unsqueeze(0)
        sticker_feature = self.avgpool(self.sticker_backbone(sticker_tensor)[layer])
        return sticker_feature

    def eval_single(self, img, sticker, mask, device):
        """Place ``sticker`` on ``img`` at the highest-scoring location.

        Args:
            img: background PIL image (its ``size`` is read and it is pasted onto).
            sticker: sticker PIL image.
            mask: PIL mask image, fed as the fifth image channel.
            device: torch device holding the model.

        Returns:
            ``(composited_image, top5)`` from ``export_case``. ``top5`` lists
            ``(x1, y1, x2, y2, score)`` for the five highest-scoring locations,
            without aspect-ratio normalization.

        Side effects:
            Fills ``self.log`` with 'cls_pred', 'raw_box' and 'normalized_box'.
        """
        eval_preprocess = self._eval_rgb_transform()

        with torch.no_grad():
            raw_img, img_alpha = read_image(img)
            raw_sticker, sticker_alpha = read_image(sticker)

            img_tensor = torch.concat((
                eval_preprocess(raw_img).transpose(-1, -2),
                self.preprocess_alpha(img_alpha).transpose(-1, -2),
                self.preprocess_mask(mask).transpose(-1, -2),
            ), dim=0).unsqueeze(0).to(device)
            sticker_tensor = torch.concat((
                eval_preprocess(raw_sticker).transpose(-1, -2),
                self.preprocess_alpha(sticker_alpha).transpose(-1, -2),
            ), dim=0).unsqueeze(0).to(device)
            cls_pred, reg_pred = self.forward(img_tensor, sticker_tensor)

            top_scores, top_indices = torch.topk(torch.sigmoid(cls_pred.squeeze(-1)[0]), 5)
            pred = int(torch.argmax(cls_pred.squeeze(-1), dim=1)[0])
            x, y, w, h = self._decode_box(reg_pred, pred)

            self.log['cls_pred'] = float(torch.sigmoid(cls_pred[0][pred]).cpu().detach().numpy()[0])
            self.log['raw_box'] = '{:.4f}, {:.4f}, {:.4f}, {:.4f}'.format(x, y, w, h)

            if config['normalize_output']:
                center_x = x + 0.5 * w
                center_y = y + 0.5 * h

                # Cap the box area at 0.8 of the background, keeping its ratio.
                if w * h > 0.8:
                    ratio = (w * h / 0.8) ** 0.5
                    w /= ratio
                    h /= ratio

                # Low-confidence predictions snap fully to the sticker's ratio.
                strength = 1 if self.log['cls_pred'] > 0.5 else 5
                print('Normalizing w/h ratio, before: w={}, h={}, strength={}.'.format(w, h, strength))
                w, h = AddStickerGeneral.normalize_wh_ratio(img, sticker, w, h, strength)
                print('After: w={}, h={}.'.format(w, h))

                x = center_x - 0.5 * w
                y = center_y - 0.5 * h

            self.log['normalized_box'] = '{:.4f}, {:.4f}, {:.4f}, {:.4f}'.format(x, y, w, h)

            pred_box = torch.empty((4,), dtype=torch.float)
            pred_box[0] = x
            pred_box[1] = y
            pred_box[2] = x + w
            pred_box[3] = y + h

            top5_boxes = []
            for idx, score in zip(top_indices.cpu().tolist(), top_scores.cpu().tolist()):
                bx, by, bw, bh = self._decode_box(reg_pred, idx)
                top5_boxes.append((float(bx), float(by), float(bx + bw), float(by + bh), score))

        return export_case(img, sticker, pred_box, is_url=False, top5=top5_boxes)


def export_case(input_url, sticker_url, position, is_url=True, top5=None):
    """Paste the sticker onto a copy of the background at ``position``.

    Args:
        input_url: background path (``is_url=True``) or background image object.
        sticker_url: sticker path (``is_url=True``) or sticker image object.
        position: ``(x1, y1, x2, y2)`` as fractions of the background size.
        is_url: whether the first two arguments are paths to open with PIL.
        top5: optional extra value returned alongside the image.

    Returns:
        The composited image, or ``(image, top5)`` when ``top5`` is given. The
        sticker is resized to at least 1x1 pixel. RGBA stickers use their own
        alpha as the paste mask.
    """
    img = Image.open(input_url) if is_url else input_url
    sticker = Image.open(sticker_url) if is_url else sticker_url

    width, height = img.size
    target_w = max(1, int(width * (position[2] - position[0])))
    target_h = max(1, int(height * (position[3] - position[1])))
    resized = sticker.resize((target_w, target_h))
    top_left = int(width * position[0]), int(height * position[1])

    composed = img.copy()
    paste_args = (resized, top_left, resized) if resized.mode == 'RGBA' else (resized, top_left)
    composed.paste(*paste_args)

    return composed if top5 is None else (composed, top5)


def evaluate(model, loader, box_info, name, export=None):
    """Run ``model`` over ``loader`` and report localisation accuracy.

    For every sample the highest-scoring location is decoded into a
    (left, top, right, bottom) box using ``box_info`` and compared with the
    ground-truth box by IoU.

    Args:
        model: Network returning ``(cls_pred, reg_pred)`` for
            ``(img_tensor, sticker_tensor)``; ``cls_pred`` holds logits.
        loader: Yields ``(img, sticker, label, gt_box, image_urls,
            sticker_urls, output_urls)`` batches.
        box_info: Tensor indexed as ``box_info[0, k, 0..2]``; entries 0/1 are
            the reference x/y of location ``k`` and entry 2 scales its
            regression offsets.
        name: Header printed before the metrics.
        export: Optional directory. If given, it is recreated and the first
            200 samples (input, sticker, expected output, prediction) are
            written to it.

    Returns:
        Fraction of samples whose predicted box has IoU > 0.5, or 0.0 when
        ``loader`` yields no samples.
    """
    if export is not None:
        # Start from an empty directory so samples from earlier runs don't mix in.
        if os.path.exists(export):
            shutil.rmtree(export)
        os.mkdir(export)
        export_count = 200
    else:
        export_count = 0

    # Evaluate in inference mode (BatchNorm uses running statistics and is not
    # updated, dropout is off) and restore the caller's mode afterwards, because
    # train() keeps optimising the same model object after each validation pass.
    was_training = model.training
    model.eval()

    total, acc3, acc5, acc7, label_acc = 0, 0, 0, 0, 0
    try:
        with torch.no_grad():
            for item in tqdm(loader):
                img_tensor, sticker_tensor, label, gt_box, image_urls, sticker_urls, output_urls = item
                img_tensor = img_tensor.cuda()
                sticker_tensor = sticker_tensor.cuda()
                cls_pred, reg_pred = model(img_tensor, sticker_tensor)

                # Highest-scoring location per sample.
                pred = torch.argmax(cls_pred.squeeze(-1), dim=1)
                batch_size = len(pred)
                rows = torch.arange(batch_size, device=pred.device)

                # Decode every winning location at once (LEFT-TOP-RIGHT-BOTTOM):
                # offsets are scaled by box_info[..., 2] and applied around
                # (box_info[..., 0], box_info[..., 1]).
                loc_info = box_info[0, pred]
                loc_reg = reg_pred[rows, pred]
                scale = loc_info[:, 2]
                pred_box = torch.stack([
                    loc_info[:, 0] - loc_reg[:, 0] * scale,
                    loc_info[:, 1] - loc_reg[:, 1] * scale,
                    loc_info[:, 0] + loc_reg[:, 2] * scale,
                    loc_info[:, 1] + loc_reg[:, 3] * scale,
                ], dim=1).float()

                # `label` is not moved to the GPU here; index it with tensors on its own device.
                label_idx = pred.to(label.device)
                label_rows = torch.arange(batch_size, device=label.device)
                label_acc += int((label[label_rows, label_idx, 0] > 0.5).sum())

                n_export = min(export_count, batch_size)
                for i in range(n_export):
                    sample_id = total + i
                    result_image = export_case(image_urls[i], sticker_urls[i], pred_box[i])
                    for url, tag in ((image_urls[i], 'input'),
                                     (sticker_urls[i], 'sticker'),
                                     (output_urls[i], 'output')):
                        shutil.copy(url, os.path.join(export, '{}_{}.{}'.format(sample_id, tag, url.split('.')[-1])))
                    # cls_pred holds logits (it is the input to sigmoid_focal_loss during
                    # training), so convert to a probability before formatting as a percentage.
                    score = torch.sigmoid(cls_pred[i, pred[i]]).item() * 100
                    result_image.save(
                        os.path.join(export, '{}_pred_{}.{}'.format(sample_id,
                                                                    format(score, '.2f'),
                                                                    image_urls[i].split('.')[-1])))
                export_count -= n_export

                gt = torch.stack(gt_box[:4]).transpose(0, 1).to(pred_box.device)
                # d_iou's signature is (pr_bboxes, gt_bboxes).
                iou = d_iou(pred_box, gt, iou_only=True)

                total += batch_size
                acc3 += int((iou > 0.3).sum())
                acc5 += int((iou > 0.5).sum())
                acc7 += int((iou > 0.7).sum())
    finally:
        model.train(was_training)

    print(name)
    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError below.
        print('No samples to evaluate.')
        return 0.0

    print('Pixel acc={}({}/{}) Acc@0.7={}({}/{}) Acc@0.5={}({}/{}), Acc@0.3={}({}/{})'.format(
        label_acc / total, label_acc, total, acc7 / total, acc7, total, acc5 / total, acc5, total, acc3 / total, acc3, total))
    wandb.log({
        'Pixel acc': label_acc / total,
        'Acc 0.7': acc7 / total,
        'Acc 0.5': acc5 / total,
        'Acc 0.3': acc3 / total,
    })
    return acc5 / total


def train(model_dir, epochs=3, resume=None, data=None):
    """Train an ``AddStickerGeneral`` model and keep the best validation checkpoint.

    Args:
        model_dir: Output directory (created if missing). Per-epoch weights are
            saved as ``tmp_{epoch}.pt`` and the best epoch's weights as
            ``best.pt``. When ``data`` is None, the loaders and ``box_info``
            are loaded from the ``sol_*.pt`` / ``box_info.pt`` files here.
        epochs: Maximum number of epochs. Training stops early once validation
            Acc@0.5 has decreased on three consecutive epochs.
        resume: Optional path to a state dict used to initialise the model.
        data: Optional ``(train_loader, val_loader, test_loader, box_info)``.
    """
    os.makedirs(model_dir, exist_ok=True)

    if data is None:
        train_loader = torch.load(os.path.join(model_dir, 'sol_train.pt'))
        val_loader = torch.load(os.path.join(model_dir, 'sol_val.pt'))
        test_loader = torch.load(os.path.join(model_dir, 'sol_test.pt'))
        box_info = torch.load(os.path.join(model_dir, 'box_info.pt'))
        print('Dataset loaded.')
    else:
        train_loader, val_loader, test_loader, box_info = data

    # Leading singleton dim lets box_info[:, :, k] broadcast against the batch dim.
    box_info = torch.tensor(box_info, dtype=torch.float, device='cuda').unsqueeze(0)
    model = AddStickerGeneral(config['backbone'], box_info)
    if resume is not None:
        model.load_state_dict(torch.load(resume))
    model = nn.DataParallel(model)

    model.train()
    model.cuda()

    best_acc, best_epoch = 0.0, 0
    # Seeded with zeros so the early-stopping comparison is defined from epoch 0.
    acc_history = [0., 0., 0.]

    cls_cof, reg_cof = config['loss_cof']

    optimizer = Adam(model.parameters(), lr=config['lr'])
    scheduler = MultiStepLR(optimizer, milestones=config['lr_milestones'], gamma=0.5)
    for epoch in range(epochs):
        for item in tqdm(train_loader):
            optimizer.zero_grad()
            img_tensor, sticker_tensor, label, gt_box, image_urls, sticker_urls, output_urls = item

            img_tensor = img_tensor.cuda()
            sticker_tensor = sticker_tensor.cuda()
            label = label.cuda()
            cls_pred, reg_pred = model(img_tensor, sticker_tensor)

            # Channel 0 of `label` marks the positive locations.
            label_box = torch.flatten(label[:, :, 0], 0)
            gt_cls = label[:, :, 0].unsqueeze(-1).float()
            cls_loss = sigmoid_focal_loss(cls_pred, gt_cls, alpha=config['focal_alpha'], reduction='sum')

            # Monitoring-only statistics; a logit > 0 counts as a positive prediction.
            pos_pred = cls_pred.detach().flatten() > 0
            pos_gt = gt_cls.flatten() > 0.5
            all_pos_pred = int(pos_pred.sum())
            all_pos_gt = int(pos_gt.sum())
            cls_precision = int(pos_gt[pos_pred].sum()) / all_pos_pred if all_pos_pred else 0
            cls_recall = int(pos_pred[pos_gt].sum()) / all_pos_gt if all_pos_gt else 0

            # Normalise by the number of positives; max(..., 1) keeps a batch with
            # no positives finite instead of dividing by zero.
            cls_loss = cls_loss / max(all_pos_gt, 1)

            if all_pos_gt:
                pred_box = torch.zeros_like(reg_pred, dtype=torch.float, device='cuda')
                # LEFT-TOP-RIGHT-BOTTOM
                pred_box[:, :, 0] = box_info[:, :, 0] - reg_pred[:, :, 0] * box_info[:, :, 2]
                pred_box[:, :, 1] = box_info[:, :, 1] - reg_pred[:, :, 1] * box_info[:, :, 2]
                pred_box[:, :, 2] = box_info[:, :, 0] + reg_pred[:, :, 2] * box_info[:, :, 2]
                pred_box[:, :, 3] = box_info[:, :, 1] + reg_pred[:, :, 3] * box_info[:, :, 2]
                gt_boxes = torch.stack(gt_box[:4]).transpose(0, 1).unsqueeze(1).expand(-1, pred_box.shape[1], -1).cuda()

                # Regress only at positive locations.
                pos_mask = label_box > 0.5
                pred_box_iou = torch.reshape(pred_box, (-1, 4))[pos_mask]
                gt_box_iou = torch.reshape(gt_boxes, (-1, 4))[pos_mask]
                _, giou_loss, _ = d_iou(pred_box_iou, gt_box_iou)
                reg_loss = torch.mean(giou_loss)
            else:
                # No positive locations: the mean over an empty selection is NaN and
                # would poison the weights through backward().
                reg_loss = reg_pred.new_zeros(())

            loss = cls_cof * cls_loss + reg_cof * reg_loss
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), 10)

            optimizer.step()

            # Log plain Python numbers rather than graph-carrying tensors.
            wandb.log({
                'cls_loss': cls_loss.item(),
                'reg_loss': reg_loss.item(),
                'total_loss': loss.item(),
                'cls_gt_pos': all_pos_gt,
                'cls_recall': cls_recall,
                'cls_precision': cls_precision,
                'pos_pred': all_pos_pred,
            })

        torch.save(model.module.state_dict(), os.path.join(model_dir, 'tmp_{}.pt'.format(epoch)))

        acc = evaluate(model, val_loader, box_info, 'Val epoch {}:'.format(epoch),
                       export=os.path.join(model_dir, 'epoch{}'.format(epoch)))
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch

        # Early stop after three consecutive drops in validation Acc@0.5.
        if acc < acc_history[-1] < acc_history[-2] < acc_history[-3]:
            break
        acc_history.append(acc)

        scheduler.step()

    evaluate(model, test_loader, box_info, 'Test final:',
             export=os.path.join(model_dir, 'sol_final'))

    model.module.load_state_dict(torch.load(os.path.join(model_dir, 'tmp_{}.pt'.format(best_epoch))))
    evaluate(model, test_loader, box_info, 'Test epoch {}:'.format(best_epoch),
             export=os.path.join(model_dir, 'sol_samples'))
    torch.save(model.module.state_dict(), os.path.join(model_dir, 'best.pt'))


if __name__ == '__main__':
    # Record this run, including the full hyper-parameter config, in Weights & Biases.
    wandb.init(project="solution", entity="user", config=config)

    # Build the three loaders plus the box_info table that train() decodes boxes with.
    train_loader, val_loader, test_loader, box_info = prepare_dataset(
        config['source'],
        True,
        True,
        model_dir=config['model_dir'],
        batch_size=config['batch_size'],
        export=config['training_data_sample'],
    )

    data = (train_loader, val_loader, test_loader, box_info)
    train(config['model_dir'], epochs=config['epochs'], data=data)
