from typing import List
import torch
import torch.nn as nn

from utils.test import to_image, from_image
from PIL import Image

class ComparedMethod(nn.Module):
    """Base class exposing a two-input colorization model through a PIL API.

    Subclasses configure the class attributes ``c_mode``, ``x_mode`` and
    ``image_size`` and implement ``forward(c, x)``, which receives batched
    tensors and returns an iterable of per-sample result tensors.
    ``colorize`` converts PIL images to tensors, validates them, runs the
    model and converts the results back with ``to_image``.

    NOTE(review): ``c`` looks like the grayscale input (default mode 'L') and
    ``x`` a colour image (default mode 'RGB'), e.g. a reference/exemplar;
    confirm against the concrete subclasses.
    """

    # PIL image mode -> number of tensor channels for that mode.
    _mode_to_ch = {'L': 1, 'RGB': 3}
    # PIL mode the first input ``c`` is converted to before inference.
    c_mode = 'L'
    # PIL mode the second input ``x`` is converted to before inference.
    x_mode = 'RGB'
    # Required square side length (pixels) of both inputs; subclasses must set it.
    image_size = None

    def __init__(self) -> None:
        super().__init__()
        # Tiny buffer used only to track which device the module lives on; it is
        # moved by ``.to()``/``.cuda()`` and kept out of the state dict.
        self.register_buffer('_dummy', torch.empty(1), persistent=False)
        # Explicit checks (not ``assert``) so a misconfigured subclass is still
        # rejected when Python runs with ``-O``.
        if self.c_mode not in self._mode_to_ch:
            raise ValueError(
                f"c_mode must be one of {sorted(self._mode_to_ch)}, got {self.c_mode!r}")
        if self.x_mode not in self._mode_to_ch:
            raise ValueError(
                f"x_mode must be one of {sorted(self._mode_to_ch)}, got {self.x_mode!r}")
        if self.image_size is None:
            raise ValueError(
                f"{type(self).__name__} must define the class attribute 'image_size'")

    @property
    def device(self) -> torch.device:
        """Device the module's buffers currently live on."""
        return self._dummy.device

    @property
    def c_ch(self) -> int:
        """Number of channels expected for ``c`` (derived from ``c_mode``)."""
        return self._mode_to_ch[self.c_mode]

    @property
    def x_ch(self) -> int:
        """Number of channels expected for ``x`` (derived from ``x_mode``)."""
        return self._mode_to_ch[self.x_mode]

    def inputs_to_tensor(self, cs: List[Image.Image], xs: List[Image.Image]):
        """Convert two lists of PIL images into batched tensors on ``self.device``.

        Every image is first converted to the configured mode (``c_mode`` /
        ``x_mode``), then turned into a tensor with ``from_image`` and stacked
        along a new leading batch dimension.

        Returns:
            A ``(c, x)`` tuple of stacked tensors.
        """
        c = torch.stack([from_image(img.convert(self.c_mode)) for img in cs])
        x = torch.stack([from_image(img.convert(self.x_mode)) for img in xs])
        # Move each whole batch once instead of transferring image by image.
        return c.to(self.device), x.to(self.device)

    def output_from_tensor(self, r):
        """Convert one result tensor back into an image via ``to_image``."""
        # detach() so tensors that still require grad can be converted safely.
        return to_image(r.detach().cpu())

    def check_inputs(self, c: torch.Tensor, x: torch.Tensor) -> None:
        """Validate batched input tensors against this method's configuration.

        Raises:
            ValueError: if a tensor is not 4-D ``(B, C, H, W)``, the batch sizes
                differ, a channel count does not match its mode, or a spatial
                size is not ``image_size`` x ``image_size``.
        """
        for name, t in (('c', c), ('x', x)):
            if t.dim() != 4:
                raise ValueError(
                    f"{name} must be a 4-D (B, C, H, W) tensor, got shape {tuple(t.shape)}")
        B1, C1, H1, W1 = c.shape
        B2, C2, H2, W2 = x.shape
        if B1 != B2:
            raise ValueError(f"batch size mismatch: c has {B1}, x has {B2}")
        if C1 != self.c_ch:
            raise ValueError(
                f"c must have {self.c_ch} channel(s) for mode {self.c_mode!r}, got {C1}")
        if C2 != self.x_ch:
            raise ValueError(
                f"x must have {self.x_ch} channel(s) for mode {self.x_mode!r}, got {C2}")
        expected = (self.image_size, self.image_size)
        if (H1, W1) != expected:
            raise ValueError(f"c must be {expected[0]}x{expected[1]}, got {H1}x{W1}")
        if (H2, W2) != expected:
            raise ValueError(f"x must be {expected[0]}x{expected[1]}, got {H2}x{W2}")

    def colorize(self, cs, xs):
        """Run the method on PIL inputs and return PIL-converted outputs.

        Args:
            cs: a single PIL image, or a list of PIL images, for input ``c``.
            xs: a single PIL image, or a list of PIL images, for input ``x``.

        Returns:
            A single output when both arguments are single images; otherwise a
            list with one output per input pair.

        Raises:
            TypeError: if exactly one of ``cs``/``xs`` is a single image.
            ValueError: if the converted inputs fail ``check_inputs``.
        """
        single = isinstance(cs, Image.Image) and isinstance(xs, Image.Image)
        if single:
            cs, xs = [cs], [xs]
        elif isinstance(cs, Image.Image) or isinstance(xs, Image.Image):
            raise TypeError(
                "cs and xs must both be single PIL images or both be lists of images")
        c, x = self.inputs_to_tensor(cs, xs)
        self.check_inputs(c, x)
        # Call the module (not .forward) so registered nn.Module hooks run.
        rs = self(c, x)
        outputs = [self.output_from_tensor(r) for r in rs]
        return outputs[0] if single else outputs

    def get_dummy_inputs(self, n=None):
        """Create blank (all-zero) inputs with the configured modes and size.

        Args:
            n: ``None`` for a single ``(c, x)`` image pair, or an int for a pair
                of lists holding ``n`` images each.
        """
        size = (self.image_size, self.image_size)
        if n is None:
            return Image.new(self.c_mode, size), Image.new(self.x_mode, size)
        # Build distinct images: ``[img] * n`` would alias a single object n
        # times, so mutating one entry would silently change all of them.
        cs = [Image.new(self.c_mode, size) for _ in range(n)]
        xs = [Image.new(self.x_mode, size) for _ in range(n)]
        return cs, xs
        