import os
import torch
from ..api import ComparedMethod


class Icon(ComparedMethod):
    """``ComparedMethod`` adapter around a pretrained generator network.

    ``forward`` receives its inputs as ``(c, x)``. The wrapped generator is
    invoked with them in the opposite order, ``(x, c)``.
    """

    # Image size associated with this method; not referenced inside this
    # class (presumably consumed by the comparison harness — confirm).
    image_size = 128

    def __init__(self, gen) -> None:
        super().__init__()
        # Generator network called by ``forward``.
        self.gen = gen

    @torch.no_grad()
    def forward(self, c, x):
        """Run the wrapped generator with autograd disabled.

        Args:
            c: First caller-supplied input; handed to the generator as its
                second positional argument.
            x: Second caller-supplied input; handed to the generator as its
                first positional argument.

        Returns:
            Whatever ``self.gen(x, c)`` returns.
        """
        generator = self.gen
        output = generator(x, c)
        return output

def load_method(output_dir='output/icon/128_1'):
    """Create an ``Icon`` method backed by a generator restored from disk.

    Args:
        output_dir: Training output directory. The generator state dict is
            read from ``<output_dir>/weights/latest_G.pth``.

    Returns:
        The ``Icon`` wrapper, after ``.eval()`` has been called on it.
    """
    # Imported here rather than at module top (presumably to defer loading
    # the network definition until it is needed — confirm).
    from .modified.network import Generator

    # Constructor args (3, 1) are presumably input/output channel counts —
    # TODO confirm against the Generator definition.
    generator = Generator(3, 1)

    checkpoint_path = os.path.join(output_dir, 'weights', 'latest_G.pth')
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True if the installed torch allows).
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    generator.load_state_dict(state_dict)

    wrapped = Icon(generator)
    return wrapped.eval()
    

if "__main__" == __name__:
    # Smoke check: build the method from the default output directory.
    method = load_method()
