import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import torch
import time
from data import build_val_transform
from datasets.cityscapes import Cityscapes
from model import RegSeg
from competitors_models.DDRNet_Reimplementation import get_ddrnet_23
from train import get_dataset_loaders
import yaml
from data_utils import get_dataloader_val
import torchvision.transforms as T
import torch.cuda.amp as amp

# Per-channel RGB normalization statistics: applied by open_image via
# T.Normalize and undone by transform_image for display.
mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))
def get_colors(dataset):
    """Return the uint8 RGB palette registered for ``dataset``.

    Args:
        dataset: one of "default", "cityscapes", "cityscapes_labelid",
            "mapillary", "mapillary_reduced" or "camvid".

    Raises:
        KeyError: if ``dataset`` has no registered palette.
    """
    palette_factories = {
        "default": get_default_colors,
        "cityscapes": get_colors_cityscapes,
        "cityscapes_labelid": get_colors_cityscapes_labelid,
        "mapillary": get_colors_mapillary,
        "mapillary_reduced": get_colors_mapillary_reduced,
        "camvid": get_colors_camvid,
    }
    build_palette = palette_factories[dataset]
    return build_palette()
def get_default_colors(num_colors=256, ignore_index=255):
    """Generate a deterministic pseudo-random palette for arbitrary label ids.

    Entry ``i`` is ``(i * [2**25 - 1, 2**15 - 1, 2**21 - 1]) % 255``, which
    spreads consecutive ids across distinct colors.

    Args:
        num_colors: number of palette entries to build. Defaults to 256 so that
            every possible uint8 label value has a color.
        ignore_index: label drawn in white, matching the other palettes in this
            module that reserve 255 for "ignore". Pass ``None`` to skip it.

    Returns:
        uint8 numpy array of shape (num_colors, 3).
    """
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.arange(num_colors).view(-1, 1) * palette
    colors = (colors % 255).numpy().astype("uint8")
    # With only 255 entries the ignore label 255 had no color at all; give it
    # the same white used by the dataset-specific palettes.
    if ignore_index is not None and 0 <= ignore_index < num_colors:
        colors[ignore_index] = 255
    return colors
def get_colors_cityscapes():
    """Build a 256-entry RGB palette indexed by Cityscapes train id.

    Train ids 0..18 get their class color, 255 is white and every other entry
    stays black.

    Returns:
        uint8 numpy array of shape (256, 3).
    """
    palette = np.zeros((256, 3))
    palette[255] = [255, 255, 255]
    for cls in Cityscapes.classes:
        # Only the 19 evaluated train ids get a color.
        if cls.train_id < 0 or cls.train_id > 18:
            continue
        palette[cls.train_id] = cls.color
    return palette.astype("uint8")
def get_colors_cityscapes_labelid():
    """Build a 256-entry RGB palette indexed by Cityscapes label id (``c.id``).

    Entry 255 is white; ids without a class stay black.

    Returns:
        uint8 numpy array of shape (256, 3).
    """
    colors = np.zeros((256, 3))
    colors[255] = [255, 255, 255]
    for c in Cityscapes.classes:
        # Only real palette slots below 255 are written. A negative id would
        # wrap around through numpy negative indexing (colors[-1] is
        # colors[255]) and overwrite the white entry. If the class table
        # follows torchvision's, "license plate" has id -1.
        if 0 <= c.id < 255:
            colors[c.id] = c.color
    return colors.astype("uint8")
def get_colors_mapillary():
    """Return the 66-color RGB palette used for Mapillary masks.

    Row ``i`` is the color of label ``i``.

    Returns:
        uint8 numpy array of shape (66, 3).
    """
    rows = (
        (165, 42, 42), (0, 192, 0), (196, 196, 196), (190, 153, 153), (180, 165, 180), (90, 120, 150),
        (102, 102, 156), (128, 64, 255), (140, 140, 200), (170, 170, 170), (250, 170, 160), (96, 96, 96),
        (230, 150, 140), (128, 64, 128), (110, 110, 110), (244, 35, 232), (150, 100, 100), (70, 70, 70),
        (150, 120, 90), (220, 20, 60), (255, 0, 0), (255, 0, 100), (255, 0, 200), (200, 128, 128),
        (255, 255, 255), (64, 170, 64), (230, 160, 50), (70, 130, 180), (190, 255, 255), (152, 251, 152),
        (107, 142, 35), (0, 170, 30), (255, 255, 128), (250, 0, 30), (100, 140, 180), (220, 220, 220),
        (220, 128, 128), (222, 40, 40), (100, 170, 30), (40, 40, 40), (33, 33, 33), (100, 128, 160),
        (142, 0, 0), (70, 100, 150), (210, 170, 100), (153, 153, 153), (128, 128, 128), (0, 0, 80),
        (250, 170, 30), (192, 192, 192), (220, 220, 0), (140, 140, 20), (119, 11, 32), (150, 0, 255),
        (0, 60, 100), (0, 0, 142), (0, 0, 90), (0, 0, 230), (0, 80, 100), (128, 64, 64),
        (0, 0, 110), (0, 0, 70), (0, 0, 192), (32, 32, 32), (120, 10, 10), (0, 0, 0),
    )
    return np.asarray(rows, dtype="uint8")
def get_colors_mapillary_reduced(min_iou=30):
    """Palette for the reduced Mapillary label set.

    Starts from the palette returned by ``get_colors_mapillary`` and keeps
    only the classes whose IoU in the hard-coded table below is strictly
    greater than ``min_iou``. The kept colors are packed contiguously from
    index 0 and the remaining slots up to 256 are black. Presumably this
    mirrors how the reduced label set was built; confirm against the dataset
    code before changing ``min_iou``.

    Args:
        min_iou: threshold a class's IoU must exceed to keep its color, in the
            same units as the table (values look like percentages). Defaults
            to 30.

    Returns:
        uint8 numpy array of shape (256, 3).

    Raises:
        ValueError: if the IoU table and the palette disagree in length.
    """
    # Reuse the single definition of the per-class colors rather than a
    # second hand-maintained copy of the same list.
    colors = get_colors_mapillary()
    # Per-class IoU, one entry per palette row.
    ious = np.array([
        0.0, 0.0, 57.68, 58.66, 63.16, 56.59, 50.9, 45.04, 39.82, 18.31,
        22.22, 45.77, 49.91, 87.97, 43.31, 70.61, 76.67, 86.43, 41.84, 66.81,
        46.77, 50.41, 0.0, 69.53, 57.07, 48.28, 4.99, 97.77, 76.83, 68.69,
        88.77, 72.93, 17.02, 22.26, 5.05, 45.31, 29.76, 0.0, 20.38, 36.26,
        2.43, 43.12, 4.4, 0.0, 37.03, 40.48, 52.69, 44.16, 60.96, 36.36,
        66.61, 43.94, 47.44, 16.69, 73.89, 89.68, 0.0, 55.74, 46.28, 22.28,
        6.71, 67.39, 8.41, 68.79, 91.75, 0,
    ])
    if len(ious) != len(colors):
        raise ValueError(
            f"IoU table has {len(ious)} entries but the palette has {len(colors)}")
    kept = colors[ious > min_iou]
    all_colors = np.zeros((256, 3), dtype=np.uint8)
    all_colors[:len(kept)] = kept
    return all_colors
def get_colors_camvid():
    """Build a 256-entry RGB palette for the 11 CamVid classes.

    Classes 0..10 get their CamVid color, 255 is white and everything else is
    black.

    Returns:
        uint8 numpy array of shape (256, 3).
    """
    class_colors = {
        0: (128, 0, 0),
        1: (128, 128, 0),
        2: (128, 128, 128),
        3: (64, 0, 128),
        4: (192, 128, 128),
        5: (128, 64, 128),
        6: (64, 64, 0),
        7: (64, 64, 128),
        8: (192, 192, 128),
        9: (0, 0, 192),
        10: (0, 128, 192),
    }
    palette = np.zeros((256, 3))
    palette[255] = [255, 255, 255]
    for class_index, rgb in class_colors.items():
        palette[class_index] = rgb
    return palette.astype("uint8")

def get_mask_func(colors):
    """Return a callable that turns a label tensor into a palettized PIL image.

    Args:
        colors: palette handed to ``PIL.Image.putpalette`` (e.g. the uint8
            arrays returned by the ``get_colors*`` helpers).

    Returns:
        Function mapping a 2-D label tensor to a paletted ``PIL.Image``; the
        labels are cast to uint8 and moved to the CPU first.
    """
    def mask_func(images):
        mask = Image.fromarray(images.byte().cpu().numpy())
        mask.putpalette(colors)
        return mask
    return mask_func

def transform_image(image, image_mean=None, image_std=None):
    """Undo input normalization so a tensor image can be shown with ``imshow``.

    Args:
        image: channels-first tensor, transposed with ``(1, 2, 0)`` to
            channels-last. Presumably a single (C, H, W) image normalized with
            the same statistics; confirm with the caller.
        image_mean: per-channel mean used for normalization; defaults to the
            module-level ``mean``.
        image_std: per-channel std used for normalization; defaults to the
            module-level ``std``.

    Returns:
        numpy array of shape (H, W, C) with values clipped to [0, 1].
    """
    channel_mean = mean if image_mean is None else np.asarray(image_mean)
    channel_std = std if image_std is None else np.asarray(image_std)
    # ``.numpy()`` refuses GPU tensors and tensors that require grad, so
    # detach and move to the CPU first.
    array = image.detach().cpu().numpy().transpose((1, 2, 0))
    array = channel_std * array + channel_mean
    return np.clip(array, 0, 1)

def add_grid_markings(image, gridsize=16):
    """Draw a square grid of 2-pixel lines over an image.

    Args:
        image: a PIL image (drawn on in place), or an array-like float image
            with values in [0, 1] (e.g. ``transform_image`` output), which is
            converted to an 8-bit PIL image first.
        gridsize: spacing in pixels between consecutive grid lines.

    Returns:
        The PIL image with the grid drawn on it.
    """
    # isinstance also accepts Image subclasses such as the ones Image.open returns.
    if not isinstance(image, Image.Image):
        image = Image.fromarray((image * 255).astype(np.uint8))
    draw = ImageDraw.Draw(image)
    width, height = image.size
    # range(gridsize, size, gridsize) yields every interior multiple of
    # gridsize, including the last one when size is not an exact multiple.
    # fill=0 is black for RGB images and palette index 0 for mask images.
    for x in range(gridsize, width, gridsize):
        draw.line([(x, 0), (x, height)], width=2, fill=0)
    for y in range(gridsize, height, gridsize):
        draw.line([(0, y), (width, y)], width=2, fill=0)
    return image

def display_image_grid(images, images_per_line, save_name=None):
    """Lay out ``images`` in a grid of subplots and show the figure.

    Args:
        images: sized sequence of anything ``imshow`` accepts (numpy arrays,
            PIL images).
        images_per_line: number of columns in the grid.
        save_name: optional output path; when truthy the figure is also
            written there with ``savefig`` before being shown.
    """
    plt.figure(figsize=(images_per_line*2, images_per_line))
    rows = int(np.ceil(len(images) / images_per_line))
    for slot, picture in enumerate(images, start=1):
        axes = plt.subplot(rows, images_per_line, slot)
        axes.axis('off')
        axes.imshow(picture)
    plt.tight_layout()
    if save_name:
        plt.savefig(save_name)
    plt.show()

def display(data_loader, show_mask, num_images=5, skip=4, images_per_line=6, must_contain=None):
    """Plot dataset images next to their colorized ground-truth masks.

    Args:
        data_loader: iterable of ``(images, targets)`` batches.
        show_mask: callable turning one target tensor into a displayable image.
        num_images: stop after this many (image, mask) pairs.
        skip: number of leading batches to discard.
        images_per_line: (image, mask) pairs per grid row.
        must_contain: optional collection of label values; when given, only
            samples whose target contains at least one of them are shown.
    """
    def contains(target):
        # An empty/None filter accepts every sample.
        if not must_contain:
            return True
        return any(label in target for label in must_contain)

    batches = iter(data_loader)
    for _ in range(skip):
        next(batches)
    columns = images_per_line * 2
    panels = []
    shown = 0
    for images, targets in batches:
        for image, target in zip(images, targets):
            if not contains(target):
                continue
            print(image.size(), target.size())
            panels.append(transform_image(image))
            panels.append(show_mask(target))
            shown += 1
            if shown == num_images:
                display_image_grid(panels, columns)
                return
    display_image_grid(panels, columns)

@torch.no_grad()
def show(model,data_loader,device,show_mask,num_images=5,skip=4,images_per_line=2,mixed_precision=False):
    """Run ``model`` on loader batches and plot image / target / prediction triples.

    Args:
        model: segmentation network returning per-class logits (N, C, h, w).
        data_loader: iterable of ``(images, targets)`` batches.
        device: device the model and batches are moved to.
        show_mask: callable turning one label tensor into a displayable image.
        num_images: stop after this many samples.
        skip: number of leading batches to discard.
        images_per_line: samples per grid row (each sample uses 3 columns).
        mixed_precision: enable ``torch.cuda.amp.autocast`` when CUDA exists.
    """
    model.eval()
    model.to(device)
    data_loader = iter(data_loader)
    for _ in range(skip):
        next(data_loader)
    all_images = []
    counter = 0
    for images, targets in data_loader:
        images, targets = images.to(device), targets.to(device)
        start = time.time()
        if torch.cuda.is_available():
            with amp.autocast(enabled=mixed_precision):
                outputs = model(images)
        else:
            outputs = model(images)
        # Resize logits to the label resolution before taking the argmax.
        outputs = torch.nn.functional.interpolate(
            outputs, size=targets.shape[-2:], mode='bilinear', align_corners=False)
        outputs = outputs.argmax(1)
        if outputs.is_cuda:
            # CUDA work is asynchronous; wait for it so the printed time covers
            # the forward pass rather than only the kernel launches.
            torch.cuda.synchronize(outputs.device)
        end = time.time()
        print(end - start)
        outputs = outputs.cpu()
        images = images.cpu()
        targets = targets.cpu()
        for image, target, output in zip(images, targets, outputs):
            print(image.size(), target.size(), output.size())
            image = transform_image(image)
            target = show_mask(target)
            output = show_mask(output)
            all_images.append(image)
            all_images.append(target)
            all_images.append(output)

            counter += 1
            if counter == num_images:
                display_image_grid(all_images, images_per_line * 3)
                return
    display_image_grid(all_images, images_per_line * 3)


def open_image(filename, size):
    """Load an image file as a normalized single-image batch.

    Args:
        filename: path to any image PIL can read; it is converted to RGB.
        size: forwarded to ``torchvision.transforms.Resize``.

    Returns:
        Float tensor of shape (1, 3, H, W), normalized with the module-level
        ``mean``/``std``.
    """
    rgb = Image.open(filename).convert("RGB")
    pipeline = T.Compose([T.Resize(size), T.ToTensor(), T.Normalize(mean=mean, std=std)])
    return pipeline(rgb).unsqueeze(0)
@torch.no_grad()
def show_files(model,files,device,show_mask,images_per_line=2,mixed_precision=False,size=1024,save_name=None,gridsize=None):
    """Run ``model`` on image files and plot each input next to its prediction.

    Args:
        model: segmentation network returning per-class logits (N, C, h, w).
        files: image paths, loaded with ``open_image(filename, size)``.
        device: device the model and inputs are moved to.
        show_mask: callable turning one label tensor into a displayable image.
        images_per_line: (image, prediction) pairs per grid row.
        mixed_precision: enable ``torch.cuda.amp.autocast`` when CUDA exists.
        size: resize argument passed to ``open_image``.
        save_name: optional path the figure is also saved to.
        gridsize: when truthy, overlay a grid with this spacing on both panels.
    """
    model.eval()
    model.to(device)
    all_images = []
    for filename in files:
        images = open_image(filename, size).to(device)
        start = time.time()
        if torch.cuda.is_available():
            with amp.autocast(enabled=mixed_precision):
                outputs = model(images)
        else:
            outputs = model(images)
        # Upsample to the resolution that was fed to the model, so the mask
        # lines up pixel-for-pixel with the displayed image (and grid overlay)
        # whatever ``size`` or the source resolution is.
        outputs = torch.nn.functional.interpolate(
            outputs, size=images.shape[-2:], mode='bilinear', align_corners=False)
        outputs = outputs.argmax(1)
        if outputs.is_cuda:
            # CUDA work is asynchronous; wait so the timing is meaningful.
            torch.cuda.synchronize(outputs.device)
        end = time.time()
        print(end - start)
        output = show_mask(outputs[0].cpu())
        image = transform_image(images[0].cpu())
        if gridsize:
            image = add_grid_markings(image, gridsize)
            output = add_grid_markings(output, gridsize)
        all_images.append(image)
        all_images.append(output)
    display_image_grid(all_images, images_per_line * 2, save_name)
def get_dataset_config(dataset):
    """Load the YAML training config registered for ``dataset``.

    Args:
        dataset: one of "cityscapes", "camvid", "mapillary", "coco",
            "synthetic".

    Returns:
        The parsed YAML document.

    Raises:
        KeyError: if ``dataset`` has no registered config file.
    """
    config_paths = {
        "cityscapes": "configs/cityscapes_500epochs.yaml",
        "camvid": "configs/camvid_200epochs.yaml",
        "mapillary": "configs/mapillary_180epochs.yaml",
        "coco": "configs/coco_100epochs.yaml",
        "synthetic": "configs/synthetic.yaml",
    }
    path = config_paths[dataset]
    with open(path) as handle:
        return yaml.full_load(handle)

def show_cityscapes_failure_modes():
    """Plot predictions of a Cityscapes RegSeg model on hand-picked val images.

    Only one model is built. Alternative checkpoints are listed in comments so
    they are not loaded and then thrown away.
    """
    num_images = 8
    images_per_line = 2
    skip = 16
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Alternative checkpoints (swap in to compare):
    #   RegSeg(name="L6_decoder26", num_classes=19,
    #          pretrained="checkpoints/L6_decoder26_1000_epochs_run2")
    #   RegSeg(name="exp48_decoder26", num_classes=19,
    #          pretrained="checkpoints/cityscapes_exp48_decoder26_trainval_1000_epochs_1024_crop_bootstrapped_run1")
    model = RegSeg(
        name="exp55_decoder26",
        num_classes=19,
        pretrained="checkpoints/exp55_decoder26_1000_epochs_run1"
    ).eval()
    val_transform = build_val_transform(1024, 1024)
    #ce loss
    indices = [319, 54, 263, 230, 318, 303, 297, 276, 312, 161, 310, 262, 360, 291, 228, 84, 121, 301, 101, 267, 285, 278, 231, 317, 125, 59, 290, 144, 27, 37]
    # ce loss
    # indices=[54, 303, 297, 360, 291, 121, 301, 285, 231, 317, 290, 27, 37]
    # weighted ce loss
    # indices=[220, 301, 26, 404, 38, 21, 195, 366, 98, 421, 250, 138, 258]
    val = Cityscapes("cityscapes_dataset", split="val", target_type="semantic",
                     transforms=val_transform, class_uniform_pct=0)
    val = torch.utils.data.Subset(val, indices)
    val_loader = get_dataloader_val(val, 0)
    colors = get_colors("cityscapes")
    mask_func = get_mask_func(colors)
    show(model, val_loader, device, mask_func, num_images=num_images, skip=skip, images_per_line=images_per_line)

def colorize_test_submission():
    """Apply the Cityscapes label-id palette to every file of a submission dir.

    Each file in ``directory`` is opened, converted to mode "L", given the
    label-id palette and saved under the same name in ``out_dir``.
    """
    import os
    colors = get_colors_cityscapes_labelid()
    directory = "exp55_decoder26_run1_submission"
    out_dir = "temp_dir"
    # Image.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    filenames = os.listdir(directory)
    print(len(filenames))
    for i, filename in enumerate(filenames):
        load_name = os.path.join(directory, filename)
        save_name = os.path.join(out_dir, filename)
        # The context manager closes the source file; convert() returns a new
        # image, so it stays usable after the source is closed.
        with Image.open(load_name) as source:
            image = source.convert("L")
        image.putpalette(colors)
        image.save(save_name)
        if (i + 1) % 100 == 0:
            print(i)
def convert_filenames(files=None):
    """Map ``*_pred`` prediction names to ``<city>/<stem>.png`` image paths.

    Each name is split on ``_``. The first field is used as the city directory
    and the last field is dropped before ``.png`` is appended, e.g.
    ``"berlin_000097_000019_leftImg8bit_pred"`` becomes
    ``"berlin/berlin_000097_000019_leftImg8bit.png"``.

    Args:
        files: names to convert; defaults to a fixed list of predictions.

    Returns:
        The converted paths (the list is also printed).
    """
    import os
    if files is None:
        files = [
            "berlin_000097_000019_leftImg8bit_pred",
            'berlin_000120_000019_leftImg8bit_pred',
            "berlin_000129_000019_leftImg8bit_pred",
            "berlin_000138_000019_leftImg8bit_pred",
            "berlin_000156_000019_leftImg8bit_pred",
            "berlin_000191_000019_leftImg8bit_pred",
            "berlin_000203_000019_leftImg8bit_pred",
            "berlin_000422_000019_leftImg8bit_pred",
            "bielefeld_000000_031244_leftImg8bit_pred",
            "bielefeld_000000_046495_leftImg8bit_pred",
            "mainz_000001_040195_leftImg8bit_pred",
            "munich_000033_000019_leftImg8bit_pred",
            "munich_000206_000019_leftImg8bit_pred",
            "munich_000203_000019_leftImg8bit_pred",
            "munich_000235_000019_leftImg8bit_pred",
            "munich_000252_000019_leftImg8bit_pred",
            "munich_000261_000019_leftImg8bit_pred",
            "munich_000396_000019_leftImg8bit_pred",
        ]
    new_files = []
    for file in files:
        parts = file.split("_")
        new_files.append(os.path.join(parts[0], "_".join(parts[:-1]) + ".png"))
    print(new_files)
    return new_files

def show_model():
    """Plot predictions of a DDRNet-23 checkpoint on the synthetic val set.

    Seeds torch, ``random`` and numpy for reproducibility, loads the
    checkpoint strictly, and calls ``show``.
    """
    import random
    seed = 0
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Alternative model:
    #   RegSeg(name="exp48_decoder26", num_classes=3,
    #          pretrained="checkpoints/synthetic_exp48_decoder26_run1").eval()
    model = get_ddrnet_23(3)
    checkpoint = torch.load("checkpoints/synthetic_ddrnet23_run1", map_location='cpu')
    # The file may hold either a bare state_dict or a dict wrapping it under "model".
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        checkpoint = checkpoint['model']
    model.load_state_dict(checkpoint, strict=True)
    num_images = 6
    images_per_line = 2
    skip = 0
    config = get_dataset_config("synthetic")
    config["num_workers"] = 0
    config["batch_size"] = 1
    _, val_loader, _ = get_dataset_loaders(config)
    colors = get_colors("default")
    mask_func = get_mask_func(colors)
    show(model, val_loader, device, mask_func, num_images=num_images, skip=skip, images_per_line=images_per_line)
def show_cityscapes_test():
    """Plot RegSeg predictions (with a 16px grid) on selected Cityscapes test images.

    The figure is saved under ``model_outputs/`` with a name derived from the
    checkpoint that is actually loaded.
    """
    import os
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    size = 1024
    # Alternative model:
    #   RegSeg(name="exp65_decoder26", num_classes=19,
    #          pretrained="checkpoints/exp65_decoder26_400_epochs_trainval_map_pretrained_run2")
    checkpoint = "checkpoints/cityscapes_exp48_decoder26_trainval_1000_epochs_1024_crop_bootstrapped_run1"
    model = RegSeg(
        name="exp48_decoder26",
        num_classes=19,
        pretrained=checkpoint
    ).eval()
    # The full list of 18 test images is the one printed by convert_filenames().
    files = ['bielefeld/bielefeld_000000_046495_leftImg8bit.png',
             'mainz/mainz_000001_040195_leftImg8bit.png',
             ]
    new_files = [os.path.join("cityscapes_dataset/leftImg8bit/test", file) for file in files]
    images_per_line = 1
    colors = get_colors("cityscapes")
    mask_func = get_mask_func(colors)
    # Name the figure after the loaded checkpoint so results are never saved
    # under a different model's name.
    save_name = os.path.join("model_outputs", os.path.basename(checkpoint) + ".pdf")
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(save_name), exist_ok=True)
    show_files(model, new_files, device, mask_func, images_per_line=images_per_line, size=size, save_name=save_name, gridsize=16)

def display_dataset():
    """Show Mapillary val images next to their ground-truth masks.

    Seeds torch, ``random`` and numpy, then loads the full (non-reduced)
    Mapillary config with a single-sample, single-process loader.
    """
    import random
    seed = 0
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)
    config = get_dataset_config("mapillary")
    config.update(num_workers=0, mapillary_reduced=False, batch_size=1)
    _, val_loader, _ = get_dataset_loaders(config)
    mask_func = get_mask_func(get_colors("mapillary"))
    display(val_loader, mask_func, num_images=8, skip=8, images_per_line=2, must_contain=None)
if __name__ == "__main__":
    # Pick the demo to run; uncomment an alternative to switch.
    show_cityscapes_failure_modes()
    # show_cityscapes_test()
    # convert_filenames()
    # display_dataset()
    # show_model()
