import os
import time
import numpy as np
import torch
from tqdm import tqdm
from datasets.mvtec import MVTecDataset
from utils.util import  AverageMeter,readYamlConfig,computeAUROC,loadWeights
from utils.functions import (
    cal_loss,
    cal_anomaly_maps,
)
from models.teacherTimm import teacherTimm
from models.model import singleNet
from utils.fourierFilter import fourierTransformST



class NetTrainer:
    """Train and evaluate a student network against a frozen teacher.

    The teacher (``teacherTimm``) is put in eval mode with gradients disabled;
    the student (``singleNet``) receives the teacher's features as input and is
    optimised with ``cal_loss(student_features, teacher_features)``.
    """

    def __init__(self, data, device):
        """Read the configuration and build models, loaders, optimizer and scheduler.

        Args:
            data: Configuration mapping. Keys read here: ``data_path``, ``obj``,
                ``save_path``, ``backbone``, ``out_indice`` and, under
                ``TrainingData``, ``img_size``, ``crop_size``, ``epochs``,
                ``lr`` and ``batch_size``.
            device: Device the models and image batches are moved to.
        """
        self.device = device
        # Fraction of the training images held out for validation.
        self.validation_ratio = 0.2
        self.data_path = data['data_path']
        self.obj = data['obj']

        training_cfg = data['TrainingData']
        self.img_resize = training_cfg['img_size']
        self.img_cropsize = training_cfg['crop_size']
        self.num_epochs = training_cfg['epochs']
        self.lr = training_cfg['lr']
        self.batch_size = training_cfg['batch_size']

        # Checkpoints live in <save_path>/models/<obj>.
        self.model_dir = f"{data['save_path']}/models/{self.obj}"
        os.makedirs(self.model_dir, exist_ok=True)
        self.modelName = data['backbone']
        self.outIndices = data['out_indice']

        self.load_model()
        self.load_dataset()

        # Only the student's parameters are optimised.
        self.optimizer = torch.optim.Adam(
            self.net.parameters(), lr=self.lr, betas=(0.5, 0.999)
        )
        # The schedule spans every training batch, so train() steps it per batch.
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=self.lr * 10,
            epochs=self.num_epochs,
            steps_per_epoch=len(self.train_loader),
        )

    def load_dataset(self):
        """Create ``self.train_loader`` and ``self.val_loader``.

        The images under ``<data_path>/<obj>/train/good`` are randomly split
        into a training and a validation subset according to
        ``self.validation_ratio``.
        """
        kwargs = (
            {"num_workers": 8, "pin_memory": True} if torch.cuda.is_available() else {}
        )
        train_dataset = MVTecDataset(
            root_dir=f"{self.data_path}/{self.obj}/train/good",
            resize_shape=[self.img_resize, self.img_resize],
            crop_size=[self.img_cropsize, self.img_cropsize],
            phase='train',
        )
        img_nums = len(train_dataset)
        valid_num = int(img_nums * self.validation_ratio)
        train_num = img_nums - valid_num
        train_data, val_data = torch.utils.data.random_split(
            train_dataset, [train_num, valid_num]
        )
        self.train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=self.batch_size, shuffle=True, **kwargs
        )
        self.val_loader = torch.utils.data.DataLoader(
            val_data, batch_size=8, shuffle=False, **kwargs
        )

    def load_model(self):
        """Instantiate the student (``self.net``) and the frozen teacher (``self.model_t``)."""
        print("loading and training SingleNet")
        self.net = singleNet().to(self.device)

        self.model_t = teacherTimm(
            backbone_name=self.modelName, out_indices=self.outIndices
        ).to(self.device)
        # The teacher is never trained: eval mode and no gradients.
        self.model_t.eval()
        for param in self.model_t.parameters():
            param.requires_grad = False

    def train(self):
        """Train the student and checkpoint it whenever the validation loss improves.

        Runs ``self.num_epochs`` epochs over ``self.train_loader``. After each
        epoch ``val()`` is called; the first epoch and every later epoch with a
        strictly lower validation loss overwrite the saved checkpoint.
        """
        print("training " + self.obj)

        best_score = None
        start_time = time.time()
        epoch_time = AverageMeter()
        total_steps = len(self.train_loader) * self.num_epochs

        # The context manager closes the bar even if training raises.
        with tqdm(total=total_steps, desc="Training", unit="batch") as epoch_bar:
            for _ in range(self.num_epochs):
                # val() leaves the student in eval mode, so training mode must be
                # restored at the start of every epoch, not just once up front.
                self.net.train()
                losses = AverageMeter()
                for sample in self.train_loader:
                    image = sample['imageBase'].to(self.device)
                    self.optimizer.zero_grad()
                    with torch.set_grad_enabled(True):
                        features_s, features_t = self.infer(image)
                        loss = cal_loss(features_s, features_t)

                        losses.update(loss.sum().item(), image.size(0))
                        loss.backward()
                        self.optimizer.step()
                        # One scheduler step per optimizer step (per batch).
                        self.scheduler.step()
                    epoch_bar.set_postfix({"Loss": loss.item()})
                    epoch_bar.update()

                val_loss = self.val(epoch_bar)
                if best_score is None or val_loss < best_score:
                    best_score = val_loss
                    self.save_checkpoint()

                epoch_time.update(time.time() - start_time)
                start_time = time.time()

        print("Training end.")

    def val(self, epoch_bar):
        """Compute the loss over ``self.val_loader`` with the student in eval mode.

        Args:
            epoch_bar: Progress bar whose postfix is set to the loss of the
                last validation batch.

        Returns:
            ``avg`` of the ``AverageMeter`` fed with each batch loss and batch
            size (presumably the sample-weighted mean loss - confirm against
            ``AverageMeter``).

        Raises:
            RuntimeError: If the validation loader yields no batches.
        """
        # NOTE: the student stays in eval mode after this call; train() switches
        # it back at the start of the next epoch.
        self.net.eval()
        losses = AverageMeter()
        last_loss = None
        with torch.no_grad():
            for sample in self.val_loader:
                image = sample['imageBase'].to(self.device)
                features_s, features_t = self.infer(image)
                last_loss = cal_loss(features_s, features_t)
                losses.update(last_loss.item(), image.size(0))

        if last_loss is None:
            # Previously this surfaced as an UnboundLocalError on `loss`.
            raise RuntimeError(
                "Validation loader produced no batches; the training set is "
                "too small for the configured validation ratio."
            )
        epoch_bar.set_postfix({"Loss": last_loss.item()})

        return losses.avg

    def save_checkpoint(self):
        """Write the student's ``state_dict`` to ``<model_dir>/singleNet.pth`` under the key ``"model"``."""
        state = {"model": self.net.state_dict()}
        torch.save(state, os.path.join(self.model_dir, "singleNet.pth"))

    @torch.no_grad()
    def test(self):
        """Score the test set with the saved student and return the first value from ``computeAUROC``.

        Loads ``singleNet.pth`` from ``self.model_dir``, runs every image under
        ``<data_path>/<obj>/test/`` through teacher and student, applies
        ``fourierTransformST`` (cutoff 10) to both feature sets and turns them
        into a score with ``cal_anomaly_maps``.

        Returns:
            The first element returned by ``computeAUROC`` (named
            ``img_roc_auc`` here; presumably the image-level ROC AUC).
        """
        self.net = loadWeights(self.net, self.model_dir, "singleNet.pth")
        # Do not rely on loadWeights (not visible here) to select inference
        # behaviour: in a test-only run the freshly built student would
        # otherwise still be in training mode.
        self.net.eval()

        kwargs = (
            {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {}
        )
        test_dataset = MVTecDataset(
            root_dir=f"{self.data_path}/{self.obj}/test/",
            resize_shape=[self.img_resize, self.img_resize],
            crop_size=[self.img_cropsize, self.img_cropsize],
            phase='test',
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=1, shuffle=False, **kwargs
        )

        scores = []
        gt_list = []
        with tqdm(test_loader) as progress_bar:
            for sample in progress_bar:
                label = sample['has_anomaly']
                image = sample['imageBase'].to(self.device)
                gt_list.extend(label.cpu().numpy())

                # Gradients are already disabled by the @torch.no_grad() decorator.
                features_s, features_t = self.infer(image)

                # The Fourier filter is applied at test time only, to both feature sets.
                features_t = fourierTransformST(features_t, cutoff=10)
                features_s = fourierTransformST(features_s, cutoff=10)

                scores.append(
                    cal_anomaly_maps(features_s, features_t, self.img_cropsize)
                )

        scores = np.asarray(scores)
        gt_list = np.asarray(gt_list)
        img_roc_auc, _ = computeAUROC(scores, gt_list, self.obj, "singleNet")

        return img_roc_auc

    def infer(self, img):
        """Run the teacher on ``img`` and the student on the teacher's output.

        Args:
            img: Input passed to the teacher model.

        Returns:
            Tuple ``(features_s, features_t)``: student output and teacher output.
        """
        features_t = self.model_t(img)
        features_s = self.net(features_t)
        return features_s, features_t

if __name__ == "__main__":
    # Script entry point: build the trainer from config.yaml and run the
    # phase requested there ("train" also evaluates afterwards).
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    data = readYamlConfig("config.yaml")
    distill = NetTrainer(data, device)

    phase = data['phase']
    if phase not in ("train", "test"):
        print("Phase argument must be train or test.")
    else:
        # Training is always followed by a test run; "test" skips straight to it.
        if phase == "train":
            distill.train()
        distill.test()

