import torch
import numpy as np
import torch.nn as nn

import wandb
from modules import IlluminationLayer, DetectorNoise
from unet import UNet
from HRNet import HighResolutionNet
import wandb
import os


class Model(nn.Module):
    """Learned illumination front-end followed by one or more network heads.

    The forward pass optionally normalizes the raw LED stack, optionally mixes
    it through an ``IlluminationLayer``, applies ``DetectorNoise`` and an
    instance normalization, and then feeds the result to ``num_heads``
    independent networks (``UNet`` or ``HighResolutionNet``).

    Args:
        num_heads: Number of independent networks fed with the same image.
        num_leds: Number of LED channels in the input stack.
        arch: ``'unet'`` or ``'hrnet'`` (case-insensitive).
        num_channels: Number of channels produced by the illumination layer.
            Overridden with ``num_leds`` when the strategy is ``'full'``.
        batch_norm: Forwarded to ``UNet`` as ``batch_norm``.
        initilization_strategy: Forwarded to ``IlluminationLayer``. The value
            ``'full'`` (case-insensitive) bypasses the illumination layer in
            ``forward`` and feeds all LED channels straight to the heads.
            ``None`` is accepted and treated as "not full".
        num_filters: Forwarded to ``UNet`` as ``start_filters``.
        task: Stored on the instance as ``self.task``; not used in this class.
        noise: Forwarded to ``DetectorNoise``.
        padding: Forwarded to the selected network constructor.
        force_positive: Forwarded to ``IlluminationLayer``.
        skipping: When true, the input is passed through ``self.imagenorm``
            first. Not supported together with ``arch='hrnet'``.
        stats_mo: Momentum of the running statistics kept by ``self.imagenorm``.

    Raises:
        RuntimeError: If ``arch`` is unknown, or if ``arch='hrnet'`` is
            combined with ``skipping=True``.
    """

    # Default output folders used by ``log_illumination`` and ``save_model``.
    PATTERN_DIR = '/hddraid5/data/colin/ctc/patterns'
    MODEL_DIR = '/hddraid5/data/colin/ctc/models'

    def __init__(self, num_heads, num_leds, arch='unet',  num_channels=1, batch_norm=False, initilization_strategy=None,
                 num_filters=16, task='hela', noise=0.0, padding=False, force_positive=True, skipping=False, stats_mo=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.task = task
        self.skipping = skipping

        # The default strategy is None, so guard the string comparison instead
        # of calling .lower() on it unconditionally.
        use_full = (isinstance(initilization_strategy, str)
                    and initilization_strategy.lower() == 'full')
        # Channel count actually seen by the heads (and by instancenorm).
        net_channels = num_leds if use_full else num_channels

        self.noise_layer = DetectorNoise(noise)
        self.hardtanh = nn.Hardtanh(-1, 1)
        # Not used in forward(); kept with its original size so the layout of
        # the state dict (and therefore existing checkpoints) is unchanged.
        self.batchnorm = nn.BatchNorm2d(num_channels)
        # Sized with the effective channel count so num_features matches the
        # tensor it normalizes when the 'full' strategy is active.
        self.instancenorm = nn.InstanceNorm2d(net_channels)
        self.imagenorm = nn.InstanceNorm2d(num_leds, track_running_stats=True, momentum=stats_mo)
        self.batch_mins = []
        self.batch_maxs = []
        # self.skip == True means forward() bypasses the illumination layer.
        self.skip = use_full
        print(net_channels)
        self.illumination_layer = IlluminationLayer(num_leds, net_channels, initilization_strategy,
                                                    force_positive=force_positive)

        arch_name = arch.lower()
        if arch_name == 'unet':
            model_args = {
                'num_classes': 1,
                'start_filters': num_filters,
                'channels_in': net_channels,
                'batch_norm': batch_norm,
                'padding': padding,
            }
            print("UNET CONFIG:", str(model_args))
            model_fn = UNet
        elif arch_name == 'hrnet':
            if skipping:
                raise RuntimeError("skipping=True is not supported with arch='hrnet'")
            model_args = {
                'num_channels': net_channels,
                'param_override': None,
                'padding': padding
            }
            model_fn = HighResolutionNet
        else:
            raise RuntimeError(f"unknown arch {arch!r}; expected 'unet' or 'hrnet'")

        # nn.ModuleList registers every head as a submodule. With a plain list
        # the heads are invisible to parameters(), state_dict() and .to().
        self.nets = nn.ModuleList(model_fn(**model_args) for _ in range(self.num_heads))
        self.run_name = self._resolve_run_name()

    @staticmethod
    def _resolve_run_name():
        """Return the last component of ``wandb.run.path``, or a random UUID hex."""
        import uuid  # local import: only needed for the fallback name

        try:
            return os.path.basename(wandb.run.path)
        except Exception:
            # No usable W&B run (e.g. wandb.run is None): fall back to a
            # unique name so save_model() always has something to write to.
            return uuid.uuid4().hex

    def forward(self, x):
        """Run the illumination pipeline and every head on ``x``.

        Args:
            x: Input batch; presumably ``(batch, num_leds, H, W)`` - TODO confirm
                against the data loader.

        Returns:
            The per-head outputs stacked along a new leading dimension, i.e.
            a tensor whose first dimension has size ``num_heads``.
        """
        if self.skipping:
            x = self.imagenorm(x)
        # With the 'full' strategy the raw LED stack is used as-is.
        illuminated_image = x if self.skip else self.illumination_layer(x)
        # adding gaussian noise, pass through if sigma is zero
        illuminated_image = self.noise_layer(illuminated_image)
        illuminated_image = self.instancenorm(illuminated_image)
        results = [net(illuminated_image) for net in self.nets]
        return torch.stack(results)

    def log_illumination(self, epoch, step):
        """Save the current illumination weights as a ``.npy`` file.

        The file is written to ``PATTERN_DIR`` and named after ``epoch`` and
        ``step``. The folder is created if it does not exist yet.
        """
        # extract the illumination layers weight
        weight = self.illumination_layer.physical_layer.weight.detach().cpu().numpy()
        # save the weights
        os.makedirs(self.PATTERN_DIR, exist_ok=True)
        weight_path = os.path.join(self.PATTERN_DIR, f'epoch_{epoch}_step_{step}.npy')
        np.save(weight_path, weight)

    def save_model(self, file_path=None, verbose=False):
        """Save the model weights to disk.

        Args:
            file_path: Destination of the full state dict. When ``None`` the
                state dict is written to ``MODEL_DIR`` under a name derived
                from ``self.run_name`` (W&B run name or UUID), and each head
                is additionally written to its own ``net_<i>_<run>.pth`` file.
            verbose: Print the paths that were written.
        """
        if file_path is not None:
            # An explicit destination was requested: honor it.
            target_dir = os.path.dirname(file_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            torch.save(self.state_dict(), file_path)
            if verbose:
                print(f"Saved model to: {file_path}")
            return

        base_folder = self.MODEL_DIR
        os.makedirs(base_folder, exist_ok=True)
        model_path = os.path.join(base_folder, f'model_{self.run_name}.pth')
        torch.save(self.state_dict(), model_path)
        # Per-head files are kept for tooling that loads heads individually.
        for u, net in enumerate(self.nets):
            net_path = os.path.join(base_folder, f'net_{u}_{self.run_name}.pth')
            torch.save(net.state_dict(), net_path)
            if verbose:
                print("saved net to : " + net_path)
        if verbose:
            print(f"Saved model to: {model_path}")
