import glob
import os
import torch
from ..api import ComparedMethod


class MUNIT(ComparedMethod):
    """Adapter exposing a pair of MUNIT generators as a ``ComparedMethod``.

    ``gen_a`` supplies the content code and ``gen_b`` supplies both the style
    code and the decoder, so the result is the content of ``c`` rendered by
    ``gen_b`` with the style extracted from ``x``.
    """

    # Class-level settings; presumably read by the ComparedMethod framework
    # to prepare inputs (TODO confirm against ..api).
    image_size = 256
    c_mode = 'RGB'

    def __init__(self, gen_a, gen_b) -> None:
        """Store the two generators (``gen_a`` first, then ``gen_b``)."""
        super().__init__()
        self.gen_a, self.gen_b = gen_a, gen_b

    @torch.no_grad()
    def forward(self, c, x):
        """Translate ``c`` using the style of ``x``; runs without gradients.

        Both inputs are mapped with ``v * 2 - 1`` before encoding and the
        decoder output is mapped back with ``r / 2 + 0.5``.
        Assumes inputs are scaled to [0, 1] -- TODO confirm with callers.
        """
        content_image = c * 2 - 1
        style_image = x * 2 - 1

        # Each encoder returns a (content, style) pair; keep one half of each.
        content_code, _unused_style = self.gen_a.encode(content_image)
        _unused_content, style_code = self.gen_b.encode(style_image)

        translated = self.gen_b.decode(content_code, style_code)
        return translated / 2 + 0.5

def load_method(output_dir='output/munit/256_1'):
    """Build an eval-mode :class:`MUNIT` from the last generator checkpoint.

    Checkpoints are searched under
    ``<output_dir>/outputs/edges2handbags_folder/checkpoints/gen_*.pt`` and
    the lexicographically last match is loaded. The network hyper-parameters
    come from ``modified/configs/edges2handbags_folder.yaml`` next to this
    file.

    Args:
        output_dir: Root directory of the MUNIT training run.

    Returns:
        A ``MUNIT`` instance wrapping both generators, switched to eval mode.

    Raises:
        FileNotFoundError: If no generator checkpoint matches the pattern.
    """
    # Resolve the checkpoint first: it depends only on ``output_dir``, so a
    # wrong path fails fast with a clear message instead of an IndexError
    # raised after both networks have already been constructed.
    pattern = os.path.join(
        output_dir, 'outputs/edges2handbags_folder/checkpoints/gen_*.pt')
    paths = sorted(glob.glob(pattern))
    if not paths:
        raise FileNotFoundError(
            f'No MUNIT generator checkpoint matches {pattern!r}')

    # Imported lazily so importing this module does not pull in the networks.
    from .modified.networks import AdaINGen
    from .modified.utils import get_config

    current_dir = os.path.abspath(os.path.dirname(__file__))
    config_path = os.path.join(current_dir, 'modified/configs/edges2handbags_folder.yaml')
    config = get_config(config_path)

    gen_a = AdaINGen(config['input_dim_a'], config['gen'])  # auto-encoder for domain a
    gen_b = AdaINGen(config['input_dim_b'], config['gen'])  # auto-encoder for domain b

    # NOTE: torch.load unpickles the file; only point this at trusted
    # checkpoints. Weights are mapped to CPU regardless of where they were
    # saved.
    state = torch.load(paths[-1], map_location='cpu')
    gen_a.load_state_dict(state['a'])
    gen_b.load_state_dict(state['b'])

    return MUNIT(gen_a, gen_b).eval()
    

if __name__ == '__main__':
    # Manual smoke test: running this file as a script loads the method with
    # the default ``output_dir``.
    # ``ei.patched`` is imported only for its import-time effects (the name is
    # not used below); it must stay ahead of ``load_method()``.
    # NOTE(review): presumably applies project-specific patches -- confirm.
    import ei.patched
    method = load_method()
