import transforms as T
from data_utils import *
from datasets.cityscapes import Cityscapes
from datasets.camvid import Camvid
from datasets.voc12 import Voc12Segmentation
from datasets.coco import Coco
from datasets.mapillary import Mapillary
from datasets.custom_dataset import SegmentationDataset

def build_val_transform(val_input_size, val_label_size):
    """Build the deterministic evaluation pipeline.

    Steps: ``T.ValResize(val_input_size, val_label_size)`` -> ``T.ToTensor()``
    -> ``T.Normalize`` with the standard ImageNet channel mean/std.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    pipeline = [
        T.ValResize(val_input_size, val_label_size),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(pipeline)
def _build_aug_transforms(aug_mode, fill, ignore_value):
    """Return the mode-specific augmentation transforms for ``aug_mode``.

    The returned list is spliced into the training pipeline between the
    horizontal flip and the final padding step.

    Raises:
        NotImplementedError: if ``aug_mode`` is not a supported mode.
    """
    supported = ("baseline", "randaug_reduced", "randaug_reduced2",
                 "colour_jitter", "rotate", "noise", "custom1",
                 "rotate5", "rotate10", "rotate15")

    def rotation(degrees):
        # Shared RandomRotation settings used by every rotation mode.
        return T.RandomRotation(degrees, mean=fill, ignore_value=ignore_value,
                                prob=1.0, expand=False)

    if aug_mode == "baseline":
        return []
    if aug_mode == "randaug_reduced":
        return [T.RandAugment(2, 0.2, "reduced", prob=1.0, fill=fill,
                              ignore_value=ignore_value)]
    if aug_mode == "randaug_reduced2":
        return [T.RandAugment(2, 0.3, "reduced2", prob=1.0, fill=fill,
                              ignore_value=ignore_value)]
    if aug_mode == "colour_jitter":
        return [T.ColorJitter(0.4, 0.4, 0.4, 0, prob=1)]
    if aug_mode == "rotate":
        return [rotation((-10, 10))]
    if aug_mode == "noise":
        return [T.AddNoise(15, prob=1.0)]
    if aug_mode == "custom1":
        return [
            T.RandAugment(2, 0.2, "reduced", prob=1.0, fill=fill,
                          ignore_value=ignore_value),
            T.AddNoise(10, prob=0.2),
        ]
    # NOTE(review): the (d, d) ranges below look like a fixed d-degree
    # rotation, unlike the symmetric (-10, 10) range of "rotate"; confirm
    # this is intended.
    if aug_mode == "rotate5":
        return [rotation((5, 5))]
    if aug_mode == "rotate10":
        return [rotation((10, 10))]
    if aug_mode == "rotate15":
        return [rotation((15, 15))]
    raise NotImplementedError(
        f"Unsupported aug_mode {aug_mode!r}; expected one of: "
        + ", ".join(supported)
    )


def build_train_transform(train_min_size, train_max_size, train_crop_size, aug_mode, ignore_value):
    """Build the random training pipeline.

    Order: ``T.RandomResize`` ("uniform" mode) -> ``T.RandomCrop`` ->
    ``T.RandomHorizontalFlip(0.5)`` -> mode-specific augmentation chosen by
    ``aug_mode`` -> ``T.RandomPad`` (given the crop size, fill colour and
    ``ignore_value``) -> ``T.ToTensor`` -> ``T.Normalize`` (ImageNet mean/std).

    Args:
        train_min_size: lower bound passed to ``T.RandomResize``.
        train_max_size: upper bound passed to ``T.RandomResize``.
        train_crop_size: an int for a square crop, or an ``(h, w)`` pair.
        aug_mode: augmentation mode name (e.g. "baseline", "rotate", ...).
        ignore_value: label value forwarded to augmentations and padding.

    Returns:
        A ``T.Compose`` of the steps above.

    Raises:
        NotImplementedError: if ``aug_mode`` is not supported.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # Mean colour scaled to 0-255, used as the fill value for padded /
    # rotated-in image regions.
    fill = tuple(int(v * 255) for v in mean)

    transforms = [T.RandomResize(train_min_size, train_max_size, "uniform")]
    if isinstance(train_crop_size, int):
        crop_h = crop_w = train_crop_size
    else:
        crop_h, crop_w = train_crop_size
    transforms.append(T.RandomCrop(crop_h, crop_w))
    transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.extend(_build_aug_transforms(aug_mode, fill, ignore_value))
    transforms.append(T.RandomPad(crop_h, crop_w, fill, ignore_value, random_pad=True))
    transforms.append(T.ToTensor())
    transforms.append(T.Normalize(mean, std))
    return T.Compose(transforms)

def get_custom(root, batch_size, train_min_size, train_max_size, train_crop_size,
               val_input_size, val_label_size, aug_mode, train_split, val_split,
               num_workers, ignore_value):
    """Create train/val loaders for a generic ``SegmentationDataset`` at ``root``.

    Returns:
        ``(train_loader, val_loader, train_dataset)``.
    """
    train_tf = build_train_transform(
        train_min_size=train_min_size,
        train_max_size=train_max_size,
        train_crop_size=train_crop_size,
        aug_mode=aug_mode,
        ignore_value=ignore_value,
    )
    val_tf = build_val_transform(val_input_size, val_label_size)

    train_set = SegmentationDataset(root, train_split, train_tf)
    val_set = SegmentationDataset(root, val_split, val_tf)

    loader_train = get_dataloader_train(train_set, batch_size, num_workers)
    loader_val = get_dataloader_val(val_set, num_workers)
    return loader_train, loader_val, train_set

def get_cityscapes(root, batch_size, train_min_size, train_max_size, train_crop_size,
                   val_input_size, val_label_size, aug_mode, class_uniform_pct,
                   train_split, val_split, num_workers, ignore_value):
    """Create train/val loaders for Cityscapes semantic segmentation.

    Returns:
        ``(train_loader, val_loader, train_dataset)``.
    """
    train_tf = build_train_transform(
        train_min_size=train_min_size,
        train_max_size=train_max_size,
        train_crop_size=train_crop_size,
        aug_mode=aug_mode,
        ignore_value=ignore_value,
    )
    val_tf = build_val_transform(val_input_size, val_label_size)

    # NOTE(review): class_uniform_pct is also passed for the validation
    # split; confirm that is intended rather than train-only.
    shared_kwargs = {"target_type": "semantic",
                     "class_uniform_pct": class_uniform_pct}
    train_set = Cityscapes(root, split=train_split, transforms=train_tf, **shared_kwargs)
    val_set = Cityscapes(root, split=val_split, transforms=val_tf, **shared_kwargs)

    loader_train = get_dataloader_train(train_set, batch_size, num_workers)
    loader_val = get_dataloader_val(val_set, num_workers)
    return loader_train, loader_val, train_set
def get_camvid(root, batch_size, train_min_size, train_max_size, train_crop_size,
               val_input_size, val_label_size, aug_mode, train_split, val_split,
               num_workers, ignore_value):
    """Create train/val loaders for CamVid.

    Returns:
        ``(train_loader, val_loader, train_dataset)``.
    """
    train_tf = build_train_transform(
        train_min_size=train_min_size,
        train_max_size=train_max_size,
        train_crop_size=train_crop_size,
        aug_mode=aug_mode,
        ignore_value=ignore_value,
    )
    val_tf = build_val_transform(val_input_size, val_label_size)

    train_set = Camvid(root, train_split, transforms=train_tf)
    val_set = Camvid(root, val_split, transforms=val_tf)

    loader_train = get_dataloader_train(train_set, batch_size, num_workers)
    loader_val = get_dataloader_val(val_set, num_workers)
    return loader_train, loader_val, train_set

def get_coco(root, batch_size, train_min_size, train_max_size, train_crop_size,
             val_input_size, val_label_size, aug_mode, num_workers, ignore_value):
    """Create loaders for the COCO "train" and "val" splits.

    Returns:
        ``(train_loader, val_loader, train_dataset)``.
    """
    train_tf = build_train_transform(
        train_min_size=train_min_size,
        train_max_size=train_max_size,
        train_crop_size=train_crop_size,
        aug_mode=aug_mode,
        ignore_value=ignore_value,
    )
    val_tf = build_val_transform(val_input_size, val_label_size)

    train_set = Coco(root, "train", train_tf)
    val_set = Coco(root, "val", val_tf)

    loader_train = get_dataloader_train(train_set, batch_size, num_workers)
    loader_val = get_dataloader_val(val_set, num_workers)
    return loader_train, loader_val, train_set
def get_mapillary(root, batch_size, train_min_size, train_max_size, train_crop_size,
                  val_input_size, val_label_size, aug_mode, num_workers,
                  ignore_value, reduced):
    """Create loaders for Mapillary (dataset version "v1.2").

    ``reduced`` is forwarded unchanged to the ``Mapillary`` constructor for
    both splits.

    Returns:
        ``(train_loader, val_loader, train_dataset)``.
    """
    train_tf = build_train_transform(
        train_min_size=train_min_size,
        train_max_size=train_max_size,
        train_crop_size=train_crop_size,
        aug_mode=aug_mode,
        ignore_value=ignore_value,
    )
    val_tf = build_val_transform(val_input_size, val_label_size)

    dataset_version = "v1.2"
    train_set = Mapillary(root, "train", train_tf, reduced, version=dataset_version)
    val_set = Mapillary(root, "val", val_tf, reduced, version=dataset_version)

    loader_train = get_dataloader_train(train_set, batch_size, num_workers)
    loader_val = get_dataloader_val(val_set, num_workers)
    return loader_train, loader_val, train_set

def get_pascal_voc(root, batch_size, train_min_size, train_max_size, train_crop_size,
                   val_input_size, val_label_size, aug_mode, num_workers, ignore_value):
    """Create loaders for PASCAL VOC 2012 ("train_aug" for training, "val").

    Returns:
        ``(train_loader, val_loader)``.

    NOTE(review): unlike the other ``get_*`` factories in this module, this
    does not also return the training dataset; callers unpacking three values
    will fail. Left as-is to avoid breaking existing two-value callers.
    """
    train_tf = build_train_transform(
        train_min_size=train_min_size,
        train_max_size=train_max_size,
        train_crop_size=train_crop_size,
        aug_mode=aug_mode,
        ignore_value=ignore_value,
    )
    val_tf = build_val_transform(val_input_size, val_label_size)

    allow_download = False
    train_set = Voc12Segmentation(root, 'train_aug', train_tf, allow_download)
    val_set = Voc12Segmentation(root, 'val', val_tf, allow_download)

    loader_train = get_dataloader_train(train_set, batch_size, num_workers)
    loader_val = get_dataloader_val(val_set, num_workers)
    return loader_train, loader_val

def count_class_nums(data_loader, num_classes):
    """Count, per class id, how many batches of ``data_loader`` contain it.

    Each ``(image, target)`` batch adds at most 1 to a class's count, no
    matter how many elements of that class it holds. Values in ``target``
    outside ``range(num_classes)`` are ignored. Progress is printed every
    100 batches, and the final counts are printed.

    Args:
        data_loader: iterable of ``(image, target)`` pairs; ``target`` holds
            class ids (tensor or anything ``torch.as_tensor`` accepts).
        num_classes: number of class ids to count.

    Returns:
        list[int] of length ``num_classes`` with the per-class batch counts.
    """
    class_counts = [0] * num_classes
    for t, (_image, target) in enumerate(data_loader):
        # One pass over the labels instead of ``num_classes`` separate
        # full-tensor membership scans.
        present = set(torch.unique(torch.as_tensor(target)).tolist())
        for i in range(num_classes):
            if i in present:
                class_counts[i] += 1
        if (t + 1) % 100 == 0:
            print(f"{t+1} done.")
    print(class_counts)
    return class_counts

def find_class_weights(dataloader, num_classes, max_iter=300, print_every=50):
    """Compute inverse-frequency class weights from label element counts.

    Element counts per class are accumulated over at most ``max_iter``
    batches (``None`` means the whole loader). Label values outside
    ``[0, num_classes)``, such as an ignore index or negative labels, are
    skipped. With ``n`` the total number of counted elements, class ``c``
    gets weight ``n / (num_classes * count_c)``. Classes that were never
    seen get ``0.0``.

    Args:
        dataloader: iterable of ``(image, target)`` pairs with ``target`` an
            integer-valued torch tensor of labels.
        num_classes: number of classes.
        max_iter: maximum number of batches to read, or ``None`` for all.
        print_every: print progress every this many batches.

    Returns:
        list[float] of length ``num_classes``.
    """
    from itertools import islice

    # int64 accumulator: float32 loses integer precision past 2**24 counts.
    counts = torch.zeros(num_classes, dtype=torch.long)
    # islice stops after exactly max_iter batches without fetching an extra one.
    for count, (_image, target) in enumerate(islice(dataloader, max_iter)):
        valid = target[(target >= 0) & (target < num_classes)]
        counts += torch.bincount(valid.long(), minlength=num_classes).cpu()
        if (count + 1) % print_every == 0:
            print(f"{count+1} done")
    n = counts.sum().item()
    return [n / (num_classes * c) if c != 0 else 0.0 for c in counts.tolist()]
