from re import L
import torch
import os
import glob
import yaml
from ..api import ComparedMethod


class ASCFT(ComparedMethod):
    """Comparison-method adapter around a pretrained ASCFT generator.

    Inputs are shifted with ``v * 2 - 1`` before reaching the generator, and
    the generator's first output is shifted back with ``/ 2 + 0.5``. This
    presumably maps images between [0, 1] and [-1, 1]; the exact value
    ranges are an assumption (TODO confirm).
    """

    # Resolution associated with this method; a class-level constant.
    image_size = 256

    def __init__(self, gen) -> None:
        # Base-class initialisation must run before ``gen`` is stored. If
        # ComparedMethod is an ``nn.Module``, submodule registration depends
        # on this order (TODO confirm the base type).
        super().__init__()
        self.gen = gen

    @torch.no_grad()
    def forward(self, c, x):
        """Run the wrapped generator without tracking gradients.

        Args:
            c: Rescaled with ``c * 2 - 1`` and passed as the generator's
                second argument.
            x: Rescaled with ``x * 2 - 1`` and passed as the generator's
                first argument.

        Returns:
            Element ``[0]`` of the generator's output, rescaled with
            ``/ 2 + 0.5``.
        """
        first_arg = x * 2 - 1
        second_arg = c * 2 - 1
        generated = self.gen(first_arg, second_arg)
        return generated[0] / 2 + 0.5

class CheckpointNotFoundError(FileNotFoundError, AssertionError):
    """Raised when no ``*-G.ckpt`` generator checkpoint exists.

    Also derives from ``AssertionError`` because earlier versions signalled
    this condition with an ``assert``. Callers catching that type keep
    working.
    """


def _find_last_checkpoint(model_dir):
    """Return the ``*-G.ckpt`` path in *model_dir* with the largest step.

    The step is the integer before the first ``-`` in the file name. The
    comparison is numeric, so ``120-G.ckpt`` beats ``30-G.ckpt``.

    Raises:
        CheckpointNotFoundError: if *model_dir* holds no ``*-G.ckpt`` file.
        ValueError: if a matching file name does not start with an integer.
    """
    # Escape the directory so '[', '*' or '?' in it are not glob wildcards.
    pattern = os.path.join(glob.escape(model_dir), '*-G.ckpt')
    candidates = glob.glob(pattern)
    if not candidates:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise CheckpointNotFoundError(f'no checkpoint found in {model_dir!r}')
    return max(
        candidates,
        key=lambda path: int(os.path.basename(path).split('-')[0]),
    )


def load_G(output_dir):
    """Build the ASCFT ``Generator`` and load its most recent checkpoint.

    Settings come from ``modified/config.yml`` next to this file. The
    checkpoint is the ``*-G.ckpt`` file with the highest step number under
    ``output_dir/<TRAINING_CONFIG.MODEL_DIR>``, loaded with
    ``map_location='cpu'``.

    Args:
        output_dir: Training output directory that contains the model dir.

    Returns:
        The ``Generator`` instance with its state dict loaded.

    Raises:
        CheckpointNotFoundError: if the model directory has no checkpoint.
    """
    from .modified.model import Generator

    config_path = os.path.join(os.path.dirname(__file__), 'modified', 'config.yml')
    # FullLoader on a config file bundled with the package. Do not point this
    # at untrusted YAML.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    training_config = config['TRAINING_CONFIG']

    model_dir = os.path.join(output_dir, training_config['MODEL_DIR'])
    last_ckpt_path = _find_last_checkpoint(model_dir)

    # str() accepts both a quoted 'True' and an unquoted YAML boolean True.
    g_spec = str(training_config['G_SPEC']) == 'True'
    G = Generator(spec_norm=g_spec)
    G.load_state_dict(torch.load(last_ckpt_path, map_location='cpu'))

    # G(reference, sketch) 256x256
    return G

def load_method(output_dir='output/ascft/256_1'):
    """Return an ``ASCFT`` wrapper around the generator loaded by ``load_G``.

    Args:
        output_dir: Training output directory passed straight to ``load_G``.
    """
    return ASCFT(load_G(output_dir))
    

if __name__ == "__main__":
    # Manual smoke check: build the method from the default output
    # directory. Nothing is printed; the run succeeds if loading succeeds.
    method = load_method()
