import os
import torch
from .train import get_net
from ..api import ComparedMethod

class Comi(ComparedMethod):
    """Inference-only wrapper exposing a Comi generator as a ComparedMethod.

    The wrapped generator is invoked as ``gen(c, x)`` and gradient tracking
    is disabled for every forward call.
    """

    # Class-level constant; presumably the input image side length this
    # method expects — confirm against the evaluation harness.
    image_size = 224

    def __init__(self, gen) -> None:
        """Keep a reference to the generator.

        Args:
            gen: callable invoked as ``gen(c, x)`` by ``forward``; in this
                file it is the ``'G'`` entry of the network from ``get_net()``.
        """
        super(Comi, self).__init__()
        self.gen = gen

    def forward(self, c, x):
        """Run the generator on ``(c, x)`` with autograd disabled.

        Returns:
            Whatever ``self.gen(c, x)`` returns, unchanged.
        """
        with torch.no_grad():
            output = self.gen(c, x)
        return output
        

def load_method(output_dir='output/comi/224_1'):
    """Build a ``Comi`` method from a saved training checkpoint.

    A fresh network is created with ``get_net()``. Then
    ``<output_dir>/checkpoint.pt`` is loaded onto the CPU, and its ``'net'``
    entry is restored into the network. Finally the network's ``'G'`` entry
    is wrapped in ``Comi``, which is switched to eval mode.

    Args:
        output_dir: directory containing ``checkpoint.pt``.

    Returns:
        The ``Comi`` instance after ``.eval()`` has been called on it.

    Note:
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    # NOTE(review): assumes get_net() returns a dict-like module container
    # (e.g. nn.ModuleDict) that has a 'G' key — confirm in .train.
    networks = get_net()
    ckpt_path = os.path.join(output_dir, 'checkpoint.pt')
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    networks.load_state_dict(checkpoint['net'])
    method = Comi(networks['G'])
    return method.eval()


if __name__ == '__main__':
    # Smoke check: load the method from the default checkpoint directory.
    method = load_method(output_dir='output/comi/224_1')
    
    