import torch
import numpy as np
from torch.autograd import Variable
import argparse
import random
import os
import models
import utils
import glob
import os.path as osp
import cv2
#import BilateralLayer as bs
import torch.nn.functional as F
import scipy.io as io
import utils
import pdb

def _load_pretrained(net, path):
    """Copy the weights stored in the checkpoint at `path` into `net`.

    The checkpoint files hold whole pickled modules (the loaded object's
    `.state_dict()` is what gets copied), so they must be fully unpickled.
    Recent PyTorch releases default `torch.load(weights_only=True)`, which
    rejects such files, so full unpickling is requested explicitly whenever
    the installed `torch.load` accepts the flag.

    SECURITY: unpickling can execute arbitrary code -- only load checkpoints
    from a trusted source.

    Returns `net` (modified in place) for convenience.
    """
    import inspect

    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(path, **load_kwargs)
    net.load_state_dict(checkpoint.state_dict())
    return net


# One shared encoder feeds both decoders; `mode` selects the decoder head
# (0 is used for albedo and 2 for roughness below -- see `models.decoder0`).
# All networks are put in eval mode because this script only runs inference.
encoder = models.encoder0(cascadeLevel = 0).eval()

albedoDecoder = models.decoder0(mode=0).eval()

roughDecoder = models.decoder0(mode=2).eval()

# Alternative checkpoints that were previously swapped in by hand:
#   PretrainModel/check_cascade0_w320_h240/{encoder0,albedo0,rough0}_13.pth
_load_pretrained(encoder, 'PretrainModel/check_cascadeIIW0/encoder0_1.pth')
_load_pretrained(albedoDecoder, 'PretrainModel/check_cascadeIIW0/albedo0_1.pth')
_load_pretrained(roughDecoder, 'PretrainModel/check_cascadeIIW0/rough0_1.pth')

# Weights are loaded first and the networks moved afterwards, so a CUDA
# device is required from here on.
device = torch.device("cuda")

encoder = encoder.to(device)

albedoDecoder = albedoDecoder.to(device)

roughDecoder = roughDecoder.to(device)

print("Finishing building Albedo, Roughness model.")

def AlbedoFunc(img, magnitude, encoder = encoder, albedoDecoder = albedoDecoder):
    """Predict an 8-bit, gamma-encoded albedo image for `img`.

    Args:
        img: H x W x C uint8 image, e.g. as returned by ``cv2.imread``. The
            channel order is passed through unchanged (C is presumably 3 --
            confirm against what the network was trained on).
        magnitude: Unused; kept so existing callers keep working.
        encoder: Encoder network; defaults to the module-level pretrained one.
        albedoDecoder: Albedo decoder; defaults to the module-level one.

    Returns:
        uint8 array of shape (rows, cols, ch) built from the decoder output.
    """
    # Place the input on whatever device the encoder lives on instead of
    # hard-coding CUDA, so CPU-resident networks can be passed in as well.
    target_device = next(encoder.parameters()).device

    # (1, C, H, W) float32 in [0, 1].
    im = (np.transpose(img, [2, 0, 1]).astype(np.float32) / 255.0)[np.newaxis, :, :, :]

    # Rescale so the brightest value is 1. An all-black frame would otherwise
    # compute 0 / 0 and feed NaNs to the network.
    peak = im.max()
    if peak > 0:
        im = im / peak

    # Inference only: no_grad avoids building and retaining an autograd graph.
    with torch.no_grad():
        # Raise to the 2.2 power (presumably linearising a display gamma).
        imBatch = torch.from_numpy(im ** 2.2).to(target_device)

        x1, x2, x3, x4, x5, x6 = encoder(imBatch)

        # Map the decoder output from [-1, 1] to [0, 1] (assumes a tanh-like
        # output range -- TODO confirm in models.decoder0).
        albedoPred = 0.5 * (albedoDecoder(imBatch, x1, x2, x3, x4, x5, x6) + 1)

        # Rescale each image so its mean value is 1/3; the clamp guards
        # against division by zero.
        bn, ch, nrow, ncol = albedoPred.size()
        flat = albedoPred.view(bn, -1)
        flat = flat / torch.clamp(torch.mean(flat, dim=1), min=1e-10).unsqueeze(1) / 3.0
        albedoPred = flat.view(bn, ch, nrow, ncol)

    # Drop only the batch axis (a bare squeeze() would also collapse a size-1
    # row or column) and move channels last: (rows, cols, ch).
    albedoPred = albedoPred.cpu().numpy()[0].transpose([1, 2, 0])

    # Re-apply the 1/2.2 gamma; clip negatives first because a fractional
    # power of a negative number is NaN.
    albedoPred = np.clip(albedoPred, 0, None) ** (1.0 / 2.2)

    albedoPredIm = np.clip(255 * albedoPred, 0, 255).astype(np.uint8)

    return albedoPredIm


def RoughFunc(img, magnitude, encoder = encoder, roughDecoder = roughDecoder):
    """Predict an 8-bit roughness map for `img`, replicated to 3 channels.

    Args:
        img: H x W x C uint8 image, e.g. as returned by ``cv2.imread``.
        magnitude: Unused; kept so existing callers keep working.
        encoder: Encoder network; defaults to the module-level pretrained one.
        roughDecoder: Roughness decoder; defaults to the module-level one.

    Returns:
        uint8 array of shape (rows, cols, 3) holding the same value in every
        channel (``cv2.COLOR_GRAY2BGR``).
    """
    # Place the input on whatever device the encoder lives on instead of
    # hard-coding CUDA, so CPU-resident networks can be passed in as well.
    target_device = next(encoder.parameters()).device

    # (1, C, H, W) float32 in [0, 1].
    im = (np.transpose(img, [2, 0, 1]).astype(np.float32) / 255.0)[np.newaxis, :, :, :]

    # Rescale so the brightest value is 1. An all-black frame would otherwise
    # compute 0 / 0 and feed NaNs to the network.
    peak = im.max()
    if peak > 0:
        im = im / peak

    # Inference only: no_grad avoids building and retaining an autograd graph.
    with torch.no_grad():
        # Raise to the 2.2 power (presumably linearising a display gamma).
        imBatch = torch.from_numpy(im ** 2.2).to(target_device)

        x1, x2, x3, x4, x5, x6 = encoder(imBatch)

        roughPred = roughDecoder(imBatch, x1, x2, x3, x4, x5, x6)

    # GRAY2BGR below needs a single-channel 2-D map, so squeeze exactly the
    # batch and channel axes; an unexpected shape raises here with a clear
    # numpy error instead of a size-1 spatial axis being dropped silently.
    roughPred = np.squeeze(roughPred.cpu().numpy(), axis=(0, 1))

    # Map [-1, 1] to [0, 255] and clip before the uint8 cast (consistent with
    # AlbedoFunc); out-of-range float -> uint8 casts are not well defined.
    roughPredIm = np.clip(255 * (0.5 * (roughPred + 1)), 0, 255).astype(np.uint8)

    roughPredIm = cv2.cvtColor(roughPredIm, cv2.COLOR_GRAY2BGR)

    return roughPredIm



if __name__ == '__main__':
    img_name = "100014.png"

    raw_img = cv2.imread(img_name)
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail with a clear message instead of a later numpy error.
    if raw_img is None:
        raise SystemExit("Could not read image: " + img_name)

    AlbedoImg = AlbedoFunc(raw_img, magnitude = 0)

    RoughImg = RoughFunc(raw_img, magnitude = 0)

    # Strip the input's extension so the outputs are e.g. "Albedo100014.png"
    # rather than "Albedo100014.png.png".
    stem = osp.splitext(img_name)[0]

    cv2.imwrite("Albedo" + stem + ".png", AlbedoImg)

    cv2.imwrite("Rough" + stem + ".png", RoughImg)