import csv
import shutil
import wandb

from PIL import Image
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import ToTensor, Resize, Normalize, RandomCrop, RandomRotation
from torch import nn
from torchvision.models import resnet101, resnet50
from torchvision.models.feature_extraction import create_feature_extractor
import os
from tqdm import tqdm
from torch.optim import Adam
import torch.nn.functional as F
import numpy as np
import sys


# Raise the interpreter recursion limit above CPython's default (1000).
# NOTE(review): the reason is not visible in this file — possibly deep
# recursion while scripting the model (torch.jit.script / FX feature
# extraction) or while pickling DataLoaders via torch.save/torch.load;
# confirm before changing or removing.
sys.setrecursionlimit(3000)


class StickerClassificationDataset(Dataset):
    """Sticker images indexed by a CSV file, labelled background vs. other.

    Each CSV row is expected to provide ``url`` (a path ``PIL.Image.open``
    can read), ``type`` and ``meta`` columns. Only rows whose ``type`` is
    ``background`` (label 0), ``attach`` or ``individual`` (label 1) are
    kept. The binary ``mask`` / ``opacity`` targets are 1 when that word
    occurs anywhere in the row's ``meta`` string.

    Each sample is
    ``(img_tensor, (w_h_ratio, sticker_has_alpha), type_label, need_mask,
    need_opacity, url)``.
    """

    def __init__(self, source):
        type_to_num = {'background': 0, 'attach': 1, 'individual': 1}

        # Random resize/crop/rotation augmentation, then 4-channel (RGBA)
        # normalisation. NOTE(review): every sample goes through this
        # pipeline, including samples that random_split later places in the
        # validation/test subsets.
        self._preprocess = torchvision.transforms.Compose([
            Resize((300, 300)),
            RandomCrop((256, 256)),
            RandomRotation(15),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406, 0.560], std=[0.229, 0.224, 0.225, 0.461])
        ])

        cnt_background, cnt_others = 0, 0
        self._data = []
        # `with` closes the index file as soon as it has been parsed;
        # newline='' is the mode the csv module expects for its input files.
        with open(source, 'r', newline='') as index_file:
            for row in csv.DictReader(index_file):
                type_label = type_to_num.get(row['type'])
                if type_label is None:
                    continue
                need_mask = 1 if 'mask' in row['meta'] else 0
                need_opacity = 1 if 'opacity' in row['meta'] else 0
                self._data.append((row['url'], type_label, need_mask, need_opacity))
                if row['type'] == 'background':
                    cnt_background += 1
                else:
                    cnt_others += 1
        print("{} samples collected in total, {}/{} are background/other.".format(
            cnt_background + cnt_others, cnt_background, cnt_others
        ))

    def __len__(self):
        return len(self._data)

    def __getitem__(self, item):
        url, stype, mask, opacity = self._data[item]
        # Close the underlying file once the pixels have been consumed; with
        # many DataLoader workers, leaked handles accumulate quickly.
        with Image.open(url) as src:
            sticker_has_alpha = 1.0 if src.mode == 'RGBA' else 0.0
            img = src if src.mode == 'RGBA' else src.convert('RGBA')
            img_tensor = self._preprocess(img)
            w, h = img.size
        # Aspect-ratio feature of the un-augmented image: 0.5 for a square
        # image, approaching 1 as it gets more elongated. The inference path
        # (eval_single) must compute exactly the same quantity.
        w_h_ratio = (w ** 2 + h ** 2) / ((w + h) ** 2)
        return img_tensor, (w_h_ratio, sticker_has_alpha), stype, mask, opacity, url


class BackgroundClassifier(nn.Module):
    """ResNet-50 based sticker classifier with three sigmoid outputs.

    The backbone's first convolution is replaced with a 4-input-channel conv
    (RGBA input) whose weights are freshly initialised. The pooled image
    feature is concatenated with two scalar meta features (``ratio`` and
    ``alpha``, each shaped ``(N, 1)`` by the callers in this file) and passed
    through an MLP that yields ``(type, mask, opacity)`` probabilities.
    """

    # Decision thresholds applied to the sigmoid outputs by callers
    # (e.g. evaluate); forward() itself does not use them.
    mask_threshold = 0.5
    opacity_threshold = 0.5
    type_threshold = 0.3

    def __init__(self):
        super().__init__()
        net = resnet50(pretrained=True)
        # 4-channel stem so the alpha channel is consumed with RGB; the
        # pretrained RGB stem weights are discarded and re-initialised.
        stem = nn.Conv2d(4, 64,
                         kernel_size=(7, 7), stride=(2, 2), padding=3, bias=False)
        net.conv1 = stem
        nn.init.kaiming_normal_(stem.weight, mode='fan_out', nonlinearity='relu')

        self.img_backbone = create_feature_extractor(net, return_nodes=['avgpool'])

        # 2050 inputs = pooled backbone channels (2048 for ResNet-50) + ratio + alpha.
        # Layer order (Linear, ReLU, Linear, ReLU, Linear) keeps the state_dict
        # keys head.0 / head.2 / head.4 that saved checkpoints rely on.
        widths = [2050, 512, 128]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], 3))
        self.head = nn.Sequential(*layers)

        self._preprocess = torchvision.transforms.Compose([
            Resize((256, 256)),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406, 0.560], std=[0.229, 0.224, 0.225, 0.461])
        ])

    def forward(self, x, ratio, alpha):
        pooled = self.img_backbone(x)['avgpool']
        features = torch.concat((pooled.flatten(1), ratio, alpha), dim=1)
        probs = self.head(features).sigmoid()
        return probs[:, 0], probs[:, 1], probs[:, 2]


def eval_single(model, img, device=torch.device('cpu')):
    """Run ``model`` on one PIL image and return its three predictions.

    The function does not change the model's train/eval mode; callers in
    this file switch it to eval mode first.

    Args:
        model: a ``BackgroundClassifier`` (or compatible module) on ``device``.
        img: a PIL image. Non-RGBA images are converted to RGBA and flagged
            as having no alpha channel.
        device: device the input tensors are moved to.

    Returns:
        ``(pred_type, pred_mask, pred_opacity)`` as 0-d tensors, except that
        ``pred_mask`` is replaced by the float ``1.`` when the input image
        had no alpha channel.
    """
    sticker_has_alpha = 1.0
    _preprocess = torchvision.transforms.Compose([
        Resize((256, 256)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406, 0.560], std=[0.229, 0.224, 0.225, 0.461])
    ])

    if img.mode != 'RGBA':
        img = img.convert('RGBA')
        sticker_has_alpha = 0.0
    img_tensor = _preprocess(img).unsqueeze(0).to(device)
    w, h = img.size
    # Must be the same feature StickerClassificationDataset.__getitem__ feeds
    # the model during training; any other formula gives the head inputs it
    # never saw (e.g. a square image must map to 0.5, not 0.0).
    w_h_ratio = (w ** 2 + h ** 2) / ((w + h) ** 2)
    meta_values = (w_h_ratio, sticker_has_alpha)
    ratio, alpha = (torch.tensor(v, dtype=torch.float).reshape((1, 1)).to(device)
                    for v in meta_values)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        pred_type, pred_mask, pred_opacity = model(img_tensor, ratio, alpha)

    pred_type, pred_mask, pred_opacity = pred_type[0], pred_mask[0], pred_opacity[0]

    print('Background prediction: {}, {}, {}'.format(pred_type.detach().cpu().numpy(),
                                                     pred_mask.detach().cpu().numpy(),
                                                     pred_opacity.detach().cpu().numpy()))

    # Without a real alpha channel, always report that a mask is needed.
    if sticker_has_alpha < 0.5:
        pred_mask = 1.

    return pred_type, pred_mask, pred_opacity


def composite(img, sticker, has_mask, has_opacity, foreground_mask):
    """Paste ``sticker`` over ``img`` (in place) and return ``img``.

    The sticker is resized to ``img.size``. A sticker without an RGBA mode is
    treated as fully opaque. When ``has_mask`` is set, the sticker's alpha is
    capped by the inverted ``foreground_mask`` (``255 - foreground_mask``),
    so pixels where the mask is 255 hide the sticker entirely. When
    ``has_opacity`` is set, the resulting alpha is halved.

    ``foreground_mask`` is only used when ``has_mask`` is true.
    NOTE(review): it presumably is a 0-255 array broadcastable to
    ``(img.height, img.width)`` — confirm against callers.
    """
    target_size = img.size
    if sticker.mode == 'RGBA':
        rgba = np.array(sticker.resize(target_size))
    else:
        width, height = target_size
        rgba = np.ones((height, width, 4), dtype=np.uint8) * 255
        rgba[:, :, :3] = np.array(sticker.convert('RGB').resize(target_size))

    alpha_scale = 0.5 if has_opacity else 1.0

    if has_mask:
        rgba[:, :, 3] = np.minimum(rgba[:, :, 3], 255 - foreground_mask)
    rgba[:, :, 3] = (rgba[:, :, 3] * alpha_scale).astype(np.uint8)

    overlay = Image.fromarray(rgba)
    img.paste(overlay, (0, 0), mask=overlay)
    return img


def prepare_data_sticker(source, batch_size=16):
    """Split the CSV-indexed sticker dataset and pickle one DataLoader per split.

    About 98% of the samples go to training; the validation split gets half
    of the remainder (rounded down) and the test split takes the rest. The
    loaders are written to ``stickercls_train_loader.pt``,
    ``stickercls_val_loader.pt`` and ``stickercls_test_loader.pt`` in the
    working directory, which is where ``train`` loads them from.
    """
    dataset = StickerClassificationDataset(source)
    n_total = len(dataset)
    n_train = int(n_total * 0.98)
    n_val = int((n_total - n_train) * 0.5)
    subsets = random_split(dataset, [n_train, n_val, n_total - n_train - n_val])

    plan = (
        ('stickercls_train_loader.pt', 8),
        ('stickercls_val_loader.pt', 4),
        ('stickercls_test_loader.pt', 4),
    )
    loaders = [
        (path, DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=workers))
        for subset, (path, workers) in zip(subsets, plan)
    ]
    for path, loader in loaders:
        torch.save(loader, path)


def _background_attr_loss(pred, gt, neg_divisor=5):
    """Weighted BCE for the mask/opacity heads, computed on background samples only.

    Negative labels (``gt < 0.5``) have their weight divided by
    ``neg_divisor``. When the batch holds no background samples both inputs
    are empty and a mean-reduced BCE would be NaN, so a zero that stays
    attached to ``pred``'s autograd graph is returned instead.
    """
    if gt.numel() == 0:
        return pred.sum() * 0.0
    weight = torch.ones_like(gt)
    weight[gt < 0.5] /= neg_divisor
    return F.binary_cross_entropy(pred, gt, weight=weight)


def train(model_dir, resume=None, total_epoch=10):
    """Train a ``BackgroundClassifier`` on the pickled sticker DataLoaders.

    ``model_dir`` is deleted and recreated. After every epoch the full model
    (``tmp_<epoch>.pt``) and a TorchScript export (``script_<epoch>.pt``) are
    written there and the model is scored on the validation loader. The
    best-scoring epoch is reloaded, its state dict saved as ``best.pt``, and
    it is evaluated on the test loader with samples exported to
    ``background/``. Requires a CUDA device.

    Args:
        model_dir: output directory (wiped on entry).
        resume: optional path to a state dict to initialise the model from.
        total_epoch: number of passes over the training loader.
    """
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)
    os.mkdir(model_dir)

    train_loader = torch.load('stickercls_train_loader.pt')
    val_loader = torch.load('stickercls_val_loader.pt')
    test_loader = torch.load('stickercls_test_loader.pt')

    device = torch.device('cuda')
    model = BackgroundClassifier()
    print(model)

    # Sanity-check inference on a known image before training starts.
    if model.training:
        model.eval()
        with Image.open('bg.png') as probe:
            print(eval_single(model, probe))
        model.train()

    if resume is not None:
        # load_state_dict updates the module in place and returns the
        # missing/unexpected keys, so its result must not replace `model`.
        model.load_state_dict(torch.load(resume, map_location='cpu'))

    model.to(device)

    best_acc, best_epoch = 0.0, 0
    optimizer = Adam(model.parameters(), lr=2e-5)
    for epoch in range(total_epoch):
        # Evaluation may have changed the mode; every epoch trains in train mode.
        model.train()
        for item in tqdm(train_loader):
            optimizer.zero_grad()
            img_tensor, meta, type_gt, mask_gt, opacity_gt, _ = item
            img_tensor = img_tensor.to(device)
            ratio, alpha = (v.unsqueeze(-1).to(device).float() for v in meta)
            type_gt = type_gt.to(device).float()
            # Mask/opacity targets only exist for background (type 0) samples.
            is_bg = type_gt < 0.5
            mask_gt = mask_gt.to(device)[is_bg].float()
            opacity_gt = opacity_gt.to(device)[is_bg].float()

            type_pred, mask_pred, opacity_pred = model(img_tensor, ratio, alpha)
            type_weight = torch.ones_like(type_gt)
            type_weight[type_gt > 0.5] /= 3
            type_loss = F.binary_cross_entropy(type_pred, type_gt, weight=type_weight)

            type_pred_one_ratio = int((type_pred > 0.5).sum().item()) / len(type_gt)

            mask_loss = _background_attr_loss(mask_pred[is_bg], mask_gt)
            opacity_loss = _background_attr_loss(opacity_pred[is_bg], opacity_gt)

            loss = type_loss * 3 + mask_loss + opacity_loss

            wandb.log({
                'type_loss': type_loss.item(),
                'mask_loss': mask_loss.item(),
                'opacity_loss': opacity_loss.item(),
                'total_loss': loss.item(),
                '1-ratio in prediction': type_pred_one_ratio,
            })

            loss.backward()
            optimizer.step()

        torch.save(model, os.path.join(model_dir, 'tmp_{}.pt'.format(epoch)))
        model_script = torch.jit.script(model)
        model_script.save(os.path.join(model_dir, 'script_{}.pt'.format(epoch)))
        acc = evaluate(model, val_loader, 'Val epoch {}'.format(epoch), device)
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch

    model = torch.load(os.path.join(model_dir, 'tmp_{}.pt'.format(best_epoch)))
    torch.save(model.state_dict(), os.path.join(model_dir, 'best.pt'))
    evaluate(model, test_loader, 'Test with epoch {}'.format(best_epoch), device, export='background')


def evaluate(model, loader, name, device, export=None):
    """Evaluate ``model`` on ``loader``, print statistics and return a combined score.

    The model is put in eval mode for the duration of the call, so layers
    such as the backbone's BatchNorm use their running statistics and do not
    update them with evaluation data. Its previous train/eval mode is
    restored afterwards.

    Args:
        model: module returning ``(type, mask, opacity)`` probabilities whose
            ``type_threshold`` / ``mask_threshold`` / ``opacity_threshold``
            attributes binarise the outputs (as ``BackgroundClassifier`` does).
        loader: iterable of batches shaped like collated
            ``StickerClassificationDataset`` samples.
        name: label printed in front of the summary.
        device: device the inputs are moved to.
        export: if given, a directory that is wiped and recreated. Every
            misclassified sample, and every correctly classified background
            sample, is copied there under a name encoding ground truth and
            prediction.

    Returns:
        ``3 * type_accuracy + mask_accuracy + opacity_accuracy``. Mask and
        opacity accuracies are measured on ground-truth background samples
        only. An accuracy whose denominator is zero counts as 0.0.
    """
    if export is not None:
        if os.path.exists(export):
            shutil.rmtree(export)
        os.mkdir(export)

    type_stat = np.zeros((2, 2), dtype=int)
    mask_stat = np.zeros((2, 2), dtype=int)
    opa_stat = np.zeros((2, 2), dtype=int)
    acc_type, acc_mask, acc_opa, total_bg = 0, 0, 0, 0
    total = 0

    was_training = model.training
    model.eval()
    try:
        print('---------------------')
        if was_training:
            with Image.open('bg.png') as probe:
                print(eval_single(model, probe, device=device))
        print('---------------------')

        with torch.no_grad():
            for item in tqdm(loader):
                img_tensor, meta, type_gt, mask_gt, opacity_gt, urls = item
                img_tensor = img_tensor.to(device)
                ratio, alpha = (v.unsqueeze(-1).to(device).float() for v in meta)
                type_gt = type_gt.to(device)

                type_pred, mask_pred, opacity_pred = model(img_tensor, ratio, alpha)
                type_weight = torch.ones_like(type_gt, dtype=torch.float)
                type_weight[type_gt > 0.5] /= 3
                type_loss = F.binary_cross_entropy(type_pred, type_gt.float(), weight=type_weight)

                type_pred_one_ratio = int((type_pred > 0.5).sum().item()) / len(type_gt)

                wandb.log({
                    'val_type_loss': type_loss.item(),
                    'val_1-ratio in prediction': type_pred_one_ratio,
                })

                # Bring labels and thresholded predictions to host memory once
                # per batch. Indexing numpy arrays with device tensors syncs per
                # element, and formatting a device tensor would put
                # "tensor(..., device=...)" into the export filenames.
                type_gt_np = type_gt.cpu().numpy()
                mask_gt_np = mask_gt.numpy()
                opacity_gt_np = opacity_gt.numpy()
                type_cls = (type_pred > model.type_threshold).long().cpu().numpy()
                mask_cls = (mask_pred > model.mask_threshold).long().cpu().numpy()
                opacity_cls = (opacity_pred > model.opacity_threshold).long().cpu().numpy()

                for i in range(len(type_gt_np)):
                    gt_type = int(type_gt_np[i])
                    gt_mask = int(mask_gt_np[i])
                    gt_opacity = int(opacity_gt_np[i])
                    stype = int(type_cls[i])
                    mask = int(mask_cls[i])
                    opacity = int(opacity_cls[i])

                    type_stat[gt_type, stype] += 1
                    if gt_type == stype:
                        acc_type += 1
                    # Mask/opacity are only meaningful for background samples.
                    if gt_type == 0:
                        total_bg += 1
                        mask_stat[gt_mask, mask] += 1
                        opa_stat[gt_opacity, opacity] += 1
                        if mask == gt_mask:
                            acc_mask += 1
                        if opacity == gt_opacity:
                            acc_opa += 1

                    if export is not None and (gt_type != stype or gt_type == stype == 0):
                        shutil.copy(urls[i],
                                    os.path.join(export, '{}_gt={}-{}-{}_pred={}-{}-{}.{}'.format(
                                        total, gt_type, gt_mask, gt_opacity,
                                        stype, mask, opacity, urls[i].split('.')[-1]
                                    )))

                    total += 1
    finally:
        model.train(was_training)

    type_acc = acc_type / total if total else 0.0
    mask_acc = acc_mask / total_bg if total_bg else 0.0
    opa_acc = acc_opa / total_bg if total_bg else 0.0

    print(name, ':')
    print('Type Accuracy: {} ({}/{})'.format(type_acc, acc_type, total))
    print(type_stat)
    print('Mask accuracy: {} ({}/{}), Opacity accuracy: {} ({}/{}).'.format(
        mask_acc, acc_mask, total_bg,
        opa_acc, acc_opa, total_bg))
    print('Mask:\n', mask_stat)
    print('Opacity:\n', opa_stat)
    return type_acc * 3 + mask_acc + opa_acc


if __name__ == '__main__':
    # Smoke test: score one sticker image on the GPU with a saved state dict.
    gpu = torch.device('cuda')
    model = BackgroundClassifier()
    model.load_state_dict(torch.load('background.pt'))
    model.eval()
    model.cuda()
    print(eval_single(model, Image.open('sticker_bg.png'), gpu))
