import os
import numpy as np
import pandas as pd

def trf(x, i):
    """Return the rows of ``x`` whose arg-max position along axis 1 equals ``i``.

    Useful for picking out the samples that were assigned to class ``i``.
    """
    row_is_class_i = np.argmax(x, axis=1) == i
    return x[row_is_class_i]

def _row_max_shift(x):
    """Per-row max of ``x`` along axis 1 (keepdims), with non-finite maxima replaced by 0.

    Subtracting this before ``np.exp`` keeps the exponentials in range; the
    non-finite guard avoids ``inf - inf = nan`` for rows whose max is +/-inf.
    """
    m = np.max(x, axis=1, keepdims=True)
    return np.where(np.isfinite(m), m, 0.0)


def softmax(x):
    """Row-wise softmax over axis 1, computed stably by shifting each row by its max."""
    x = np.asarray(x)
    z = np.exp(x - _row_max_shift(x))
    return z / np.sum(z, axis=1, keepdims=True)


def pred_class(x):
    """Predicted class: index of the largest entry along axis 1."""
    return np.argmax(x, axis=1)


def ene_score(x):
    """Energy score: row-wise log-sum-exp over axis 1, computed without overflow.

    Uses ``logsumexp(x) = m + log(sum(exp(x - m)))`` with ``m`` the row max.
    """
    x = np.asarray(x)
    m = _row_max_shift(x)
    return np.log(np.sum(np.exp(x - m), axis=1)) + np.squeeze(m, axis=1)


def max_score(x):
    """Max-logit score: largest entry along axis 1."""
    return np.max(x, axis=1)


def soft_score(x):
    """Maximum softmax probability along axis 1."""
    return np.max(softmax(x), axis=1)


def threshold_tpr95(data, tpr=95):
    """Return the score threshold at which ``tpr`` percent of ``data`` lies at or above it.

    The scores are flattened and sorted ascending. The element at index
    ``int(N * (100 - tpr) / 100)`` is returned, so about ``tpr``% of the
    scores are >= the threshold. For the default ``tpr=95`` this is exactly
    ``sorted[int(N / 20)]``.

    Args:
        data: array-like of scores; any shape (it is flattened).
        tpr: target true-positive rate in percent, with ``0 < tpr <= 100``.

    Returns:
        The threshold value, a scalar taken from ``data``.

    Raises:
        ValueError: if ``data`` is empty or ``tpr`` is out of range.
    """
    if not 0 < tpr <= 100:
        raise ValueError(f"tpr must be in (0, 100], got {tpr}")
    # Ravel first: np.sort only sorts the last axis, so on N-D input indexing
    # the sorted array would return a whole sub-array instead of one value.
    tmp_sorted = np.sort(np.ravel(np.asarray(data)))
    n = tmp_sorted.size
    if n == 0:
        raise ValueError("threshold_tpr95 requires at least one score")
    th = int(n * (100 - tpr) / 100)
    return tmp_sorted[th]


# ANSI terminal escape sequences: bright-red text, bold, underline, and reset.
RED, BOLD, UNDERLINE, END = (
    '\033[91m',
    '\033[1m',
    '\033[4m',
    '\033[0m',
)

# Scorer hyperparameters below were selected by grid search.
from pyod.models.knn import KNN
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import StandardScaler

# Kernel cache size handed to every OneClassSVM (sklearn measures it in MB);
# the CIFAR-100 entries use one fifth of it.
cache_size = 4000

# One-class SVM scorers keyed by '<dataset>_<network>'. All use a polynomial
# kernel with gamma=1.0; only nu (and the cache size) differ between entries.
# For 'svhn_wrn', coef0 is deliberately left at its default.
scorer_ocsvm = {
    'cifar10_wrn': OneClassSVM(kernel='poly', gamma=1.0, nu=0.1, cache_size=cache_size),
    'cifar100_wrn': OneClassSVM(kernel='poly', gamma=1.0, nu=0.1, cache_size=cache_size / 5),
    'cifar10_densenet': OneClassSVM(kernel='poly', gamma=1.0, nu=0.4, cache_size=cache_size),
    'cifar100_densenet': OneClassSVM(kernel='poly', gamma=1.0, nu=0.4, cache_size=cache_size / 5),
    'svhn_wrn': OneClassSVM(kernel='poly', gamma=1.0, nu=0.8, cache_size=cache_size),
    'gtsrb_alexnet': OneClassSVM(kernel='poly', gamma=1.0, nu=0.8, cache_size=cache_size),
}

# pyod KNN scorers keyed by '<dataset>_<network>'. All entries share
# contamination=0.05, the Bray-Curtis metric, algorithm='auto' and radius=1.0;
# they differ only in n_neighbors and the neighbour-distance aggregation method.
scorer_knn = {
    'cifar10_wrn': KNN(contamination=0.05, n_neighbors=4, metric='braycurtis',
                       method='median', algorithm='auto', radius=1.0),
    'cifar100_wrn': KNN(contamination=0.05, n_neighbors=4, metric='braycurtis',
                        method='median', algorithm='auto', radius=1.0),
    'cifar10_densenet': KNN(contamination=0.05, n_neighbors=15, metric='braycurtis',
                            method='largest', algorithm='auto', radius=1.0),
    'cifar100_densenet': KNN(contamination=0.05, n_neighbors=15, metric='braycurtis',
                             method='largest', algorithm='auto', radius=1.0),
    'svhn_wrn': KNN(contamination=0.05, n_neighbors=5, metric='braycurtis',
                    method='mean', algorithm='auto', radius=1.0),
    'gtsrb_alexnet': KNN(contamination=0.05, n_neighbors=4, metric='braycurtis',
                         method='mean', algorithm='auto', radius=1.0),
}


# Human-readable display labels for raw dataset identifiers. Judging by the
# name, this is presumably passed to DataFrame.rename(index=...); verify
# against the callers.
index_rename = {
    'animalscustom': 'Animals',
    'animefacedata': 'Anime Faces',
    'fish': 'Fishes',
    'fruit': 'Fruits',
    'isun': 'iSUN',
    'jigsaw_training': 'Jigsaw Training',
    'lsunc': 'LSUN-Crop',
    'lsunr': 'LSUN-Resize',
    'officehome1art': 'Office-Home Art',
    'officehome2clipart': 'Office-Home Clipart',
    'officehome3product': 'Office-Home Product',
    'officehome4real': 'Office-Home Real',
    'pacs1photo': 'PACS Photo',
    'pacs2art': 'PACS Art',
    'pacs3cartoon': 'PACS Cartoon',
    'pacs4sketch': 'PACS Sketch',
    'places365': 'Places365',  # dataset is "Places365" (was misspelled "Place365")
    'svhn': 'SVHN',
    'textures': 'Texture',
    'jigsawtrain': 'Jigsaw on Training Set',
}