import os
import re
import torch
import torch.nn as nn

from ..api import ComparedMethod
from utils.test import from_image, to_image


class Anime(ComparedMethod):
    """Comparison wrapper that pairs a generator with a style encoder.

    At inference time the style encoder processes ``x`` and its output,
    together with ``c``, is passed to the generator. Gradient tracking is
    disabled for the forward pass.
    """

    # Class-level resolution attribute; presumably read by ComparedMethod
    # users to size inputs -- TODO confirm against the API module.
    image_size = 256

    def __init__(self, G, en):
        # Base initialiser must run before sub-modules are attached.
        super().__init__()
        self.G = G
        self.en = en

    @torch.no_grad()
    def forward(self, c, x):
        """Return ``G(c, en(x))`` without building an autograd graph."""
        style = self.en(x)
        return self.G(c, style)


def load_method(output_dir='output/anime/256_2'):
    """Build an evaluation-mode ``Anime`` method from a saved checkpoint.

    Reads ``<output_dir>/checkpoint.pt`` onto the CPU, takes its ``'net'``
    entry, and loads every ``G.``-prefixed tensor (prefix stripped) into a
    fresh ``Generator``.

    Args:
        output_dir: Directory that contains ``checkpoint.pt``.

    Returns:
        An ``Anime`` instance switched to eval mode.
    """
    from .model.method import Generator, StyleEncoder

    # Construct both networks before touching the checkpoint (same order
    # as before, so any constructor-side RNG use is unchanged).
    generator = Generator()
    style_encoder = StyleEncoder(256)

    checkpoint_path = os.path.join(output_dir, 'checkpoint.pt')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net_state = checkpoint['net']

    # Keep only generator parameters ("G.<name>") and drop the "G." prefix
    # so the names match Generator's own state_dict.
    strip_prefix = re.compile(r'^G\.')
    generator_state = {}
    for name, tensor in net_state.items():
        if re.match(r'G\..+', name):
            generator_state[strip_prefix.sub('', name)] = tensor
    generator.load_state_dict(generator_state)

    # NOTE(review): nothing from the checkpoint is loaded into
    # `style_encoder`; it keeps whatever weights StyleEncoder(256) sets up
    # itself. Confirm this is intended (e.g. a fixed/pretrained encoder).
    return Anime(generator, style_encoder).eval()
    


if __name__ == '__main__':
    # Manual smoke test: load the method, move it to a device, then open an
    # interactive shell (third-party `ei` package) to experiment with it.
    G = load_method()
    # Fall back to CPU so the script does not crash on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    G.to(device)
    # Example call to try from the shell (input shapes are assumptions taken
    # from an earlier draft -- confirm against the model):
    #   c = torch.randn(1, 1, 256, 256).to(device)
    #   x = torch.randn(1, 3, 256, 256).to(device)
    #   print(G(c, x).shape)
    import ei
    ei.embed()
    