import os
import numpy as np
import torch
from tqdm import tqdm
from datasets.mvtec import MVTecDataset
from utils.util import  readYamlConfig,computeAUROC,loadWeights
from utils.functions import cal_anomaly_maps
from models.teacherTimm import teacherTimm
from models.model import singleNet
from utils.fourierFilter import fourierTransformST


class NetTester:
    """Evaluate a trained SingleNet student against a frozen timm teacher.

    For one category (``data['obj']``) the tester loads the student's
    ``best.pth`` weights from ``bestWeights/<obj>``, feeds the test split
    through teacher and student, compares their Fourier-filtered features with
    ``cal_anomaly_maps`` and reports the image-level AUROC via ``computeAUROC``.
    """

    def __init__(self, data, device):
        """Read the run configuration and load both networks.

        Args:
            data: parsed config mapping; reads ``data_path``, ``obj``,
                ``TrainingData.img_size``, ``TrainingData.crop_size``,
                ``backbone`` and ``out_indice``.
            device: ``torch.device`` that the networks and images are moved to.
        """
        self.device = device
        self.data_path = data['data_path']
        self.obj = data['obj']
        self.img_resize = data['TrainingData']['img_size']
        self.img_cropsize = data['TrainingData']['crop_size']
        self.model_dir = "bestWeights/" + self.obj
        # NOTE(review): this directory is only read from; creating it keeps the
        # original behaviour but means a missing checkpoint fails later, in loadWeights.
        os.makedirs(self.model_dir, exist_ok=True)
        self.modelName = data['backbone']
        self.outIndices = data['out_indice']
        self.load_model()

    def load_model(self):
        """Load the student weights and build the frozen teacher, both in eval mode."""
        print("loading and training SingleNet")
        self.net = singleNet().to(self.device)
        self.net = loadWeights(self.net, self.model_dir, "best.pth")
        # The student is only evaluated here. Eval mode stops layers such as
        # BatchNorm/Dropout (if present) from using per-batch statistics with
        # batch_size=1 and from updating their running buffers during testing.
        self.net.eval()

        self.model_t = teacherTimm(
            backbone_name=self.modelName, out_indices=self.outIndices
        ).to(self.device)
        self.model_t.eval()
        for param in self.model_t.parameters():
            param.requires_grad = False

    def _make_test_loader(self):
        """Build a sequential, batch-size-1 DataLoader over ``<data_path>/<obj>/test/``."""
        kwargs = (
            {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {}
        )
        test_dataset = MVTecDataset(
            root_dir=self.data_path + "/" + self.obj + "/test/",
            resize_shape=[self.img_resize, self.img_resize],
            crop_size=[self.img_cropsize, self.img_cropsize],
            phase='test',
        )
        return torch.utils.data.DataLoader(
            test_dataset, batch_size=1, shuffle=False, **kwargs
        )

    @torch.no_grad()
    def test(self, cutoff=10):
        """Score every test image and return the image-level AUROC.

        Args:
            cutoff: value passed as ``cutoff`` to ``fourierTransformST`` for both
                teacher and student features (default 10, the former hard-coded value).

        Returns:
            The first element of ``computeAUROC(scores, labels, obj, "singleNet")``.
        """
        scores = []
        gt_list = []
        # Gradients are already disabled by the @torch.no_grad() decorator.
        for sample in tqdm(self._make_test_loader()):
            label = sample['has_anomaly']
            image = sample['imageBase'].to(self.device)
            gt_list.extend(label.cpu().numpy())

            features_s, features_t = self.infer(image)
            features_t = fourierTransformST(features_t, cutoff=cutoff)
            features_s = fourierTransformST(features_s, cutoff=cutoff)
            scores.append(cal_anomaly_maps(features_s, features_t, self.img_cropsize))

        img_roc_auc, _ = computeAUROC(
            np.asarray(scores), np.asarray(gt_list), self.obj, "singleNet"
        )
        return img_roc_auc

    def infer(self, img):
        """Run the teacher on ``img`` and the student on the teacher's output.

        Returns:
            ``(features_s, features_t)``, with the student output first and the
            teacher output second.
        """
        features_t = self.model_t(img)
        features_s = self.net(features_t)
        return features_s, features_t



import sys  
# Texture categories evaluated for each supported dataset command-line flag.
DATASET_TEXTURES = {
    "--mvtec": ['tile', 'wood', 'grid', 'carpet', 'leather'],
    "--tilda": ['tilda1', 'tilda2', 'tilda3', 'tilda4',
                'tilda5', 'tilda6', 'tilda7', 'tilda8'],
    "--aitex": ['aitex1', 'aitex2', 'aitex3', 'aitex4'],
    "--dagm": ['dagm1', 'dagm2', 'dagm3', 'dagm4', 'dagm5', 'dagm6'],
}


def select_textures(argv):
    """Return the texture list selected by the dataset flag in ``argv[1]``.

    With no flag, MVTec AD textures are used.

    Raises:
        SystemExit: if the flag is not a key of ``DATASET_TEXTURES``. Previously
            this surfaced as a NameError on the unbound ``textures`` variable.
    """
    if len(argv) < 2:
        print("default to MVTEC AD textures")
        dataset = "--mvtec"
    else:
        dataset = argv[1]
    try:
        return DATASET_TEXTURES[dataset]
    except KeyError:
        sys.exit(
            f"unknown dataset flag {dataset!r}; expected one of: "
            + ", ".join(DATASET_TEXTURES)
        )


if __name__ == "__main__":
    textures = select_textures(sys.argv)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = readYamlConfig("config.yaml")
    for texture in textures:
        data['obj'] = texture
        distill = NetTester(data, device)
        auroc = distill.test()
        print(f"{texture}: image AUROC = {auroc}")

