"""Safe Control Net"""
import math
import random
import torch
from ControlNet.cldm.cldm import ControlLDM, ControlNet
from ControlNet.ldm.modules.diffusionmodules.util import timestep_embedding
from ControlNet.ldm.util import default, instantiate_from_config
from ControlNet.ldm.models.diffusion.ddpm import disabled_train
from config_control import Config


def symsigmoid(x):
    r"""Apply the symmetric sigmoid :math:`|x| \cdot (2\sigma(x) - 1)` element-wise.

    Since :math:`2\sigma(x) - 1 = \tanh(x / 2)`, the result is an odd function
    of ``x`` that is zero at the origin and approaches ``x`` for large ``|x|``.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # The docstring is raw so that ``\sigma`` is not parsed as an (invalid)
    # string escape; ``torch.sigmoid`` is the non-legacy spelling of
    # ``torch.nn.functional.sigmoid`` and yields the same values.
    return torch.abs(x) * (2 * torch.sigmoid(x) - 1)


class OurClipControlLDM(ControlLDM):
    """ControlLDM variant that feeds a first-stage-encoded hint to the control model.

    ``forward`` draws the diffusion timesteps (optionally bounded below by
    ``Config().train.timestep_start``) and ``p_losses`` returns the raw output
    of ``self.control_model`` instead of a loss value.
    """

    def forward(self, x, c, *args, **kwargs):
        """Sample timesteps, prepare the conditioning and delegate to ``p_losses``.

        Args:
            x: Input batch; ``x.shape[0]`` sizes the batch of timesteps.
            c: Conditioning; must not be ``None`` when the wrapped model has a
                conditioning key.
            *args: Extra positional arguments forwarded to ``p_losses``.
            **kwargs: Extra keyword arguments forwarded to ``p_losses``.

        Returns:
            Whatever ``p_losses`` returns.
        """
        # Lower bound of the sampled timesteps: the configured value when it
        # exists and is positive, otherwise 0.
        train_cfg = Config().train
        lower = getattr(train_cfg, "timestep_start", 0)
        if not lower > 0:
            lower = 0

        # ``torch.randint`` samples from [lower, num_timesteps).
        batch = x.shape[0]
        t = torch.randint(lower, self.num_timesteps, (batch,), device=self.device)
        t = t.long()

        if self.model.conditioning_key is None:
            return self.p_losses(x, c, t, *args, **kwargs)

        assert c is not None
        if self.cond_stage_trainable:
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            cond_t = self.cond_ids[t].to(self.device)
            cond_noise = torch.randn_like(c.float())
            c = self.q_sample(x_start=c, t=cond_t, noise=cond_noise)
        return self.p_losses(x, c, t, *args, **kwargs)

    def p_losses(self, x_start, cond, t, noise=None):
        """Run the control model on ``x_start`` and the encoded hint.

        Despite its name, this returns the output of ``self.control_model``
        rather than a loss.

        Args:
            x_start: Input passed (possibly perturbed) to the control model.
            cond: Mapping with ``"c_crossattn"`` and ``"c_concat"`` entries,
                each a sequence of tensors concatenated along dim 1.
            t: Timesteps passed to ``q_sample`` and to the control model.
            noise: Optional noise for ``q_sample``; drawn with
                ``torch.randn_like(x_start)`` when omitted.

        Returns:
            The output of ``self.control_model``.
        """
        # The noise is always materialised, even in training mode where it is
        # unused, so the amount of randomness consumed does not depend on the mode.
        noise = default(noise, lambda: torch.randn_like(x_start))

        # Training feeds ``x_start`` untouched; otherwise the ``q_sample``
        # output is added on top of ``x_start``.
        x_noisy = (
            x_start
            if self.training
            else x_start + self.q_sample(x_start=x_start, t=t, noise=noise)
        )

        text_context = torch.cat(cond["c_crossattn"], 1)
        latent_hint = self._encode_hint(cond)
        return self.control_model(
            x=x_noisy,
            hint=latent_hint,
            timesteps=t,
            context=text_context,
        )

    def _encode_hint(self, cond):
        """Concatenate ``cond["c_concat"]`` and encode it with the first-stage model."""
        raw_hint = torch.cat(cond["c_concat"], 1)
        # Affine rescale before encoding; presumably maps [0, 1] to [-1, 1] —
        # TODO confirm the hint's value range with the data pipeline.
        rescaled = 2 * raw_hint - 1
        encoded = self.first_stage_model.encode(rescaled)
        return self.get_first_stage_encoding(encoded).detach()

    def instantiate_first_stage(self, config):
        """Build the first-stage model from ``config`` and freeze it.

        The model is put in eval mode, its ``train`` attribute is replaced by
        ``disabled_train`` and every parameter has ``requires_grad`` turned off.
        """
        first_stage = instantiate_from_config(config)
        self.first_stage_config = config
        self.first_stage_model = first_stage.eval()

        frozen = self.first_stage_model
        frozen.train = disabled_train
        for weight in frozen.parameters():
            weight.requires_grad = False


class OurClipControlNet(ControlNet):
    """Control network that perturbs and quantizes ``hint + x``.

    ``forward`` adds the hint to the input, applies ``symsigmoid`` outside of
    training mode, adds one noise tensor that is sampled on the first call and
    reused afterwards, and finally round-trips the result through fp16.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        """Forward every argument unchanged to ``ControlNet.__init__``.

        ``channel_mult`` defaults to ``(1, 2, 4, 8)``, the default of the
        upstream ``ControlNet`` constructor. The previous default was the
        ``Ellipsis`` placeholder, which was passed on to the parent whenever
        the caller omitted the argument.
        """
        super().__init__(
            image_size,
            in_channels,
            model_channels,
            hint_channels,
            num_res_blocks,
            attention_resolutions,
            dropout,
            channel_mult,
            conv_resample,
            dims,
            use_checkpoint,
            use_fp16,
            num_heads,
            num_head_channels,
            num_heads_upsample,
            use_scale_shift_norm,
            resblock_updown,
            use_new_attention_order,
            use_spatial_transformer,
            transformer_depth,
            context_dim,
            n_embed,
            legacy,
            disable_self_attentions,
            num_attention_blocks,
            disable_middle_self_attn,
            use_linear_in_transformer,
        )
        # Sampled lazily on the first forward pass, then reused for every call.
        # NOTE(review): this is a plain attribute, not a registered buffer, so
        # it is presumably not stored in checkpoints — confirm that the same
        # noise is available wherever the output has to be reproduced.
        self.fix_noise = None

    def forward(self, x, hint, timesteps, context, **kwargs):
        """Combine ``x`` and ``hint`` into the perturbed, fp16-quantized output.

        Args:
            x: Input tensor; cast to ``self.dtype`` before the addition.
            hint: Tensor added to ``x``. The sum is indexed up to
                ``shape[3]``, so it needs at least four dimensions.
            timesteps: Unused.
            context: Unused.
            **kwargs: Unused.

        Returns:
            A ``torch.float32`` tensor with the shape of ``hint + x``.
        """
        h = hint + x.type(self.dtype)
        if not self.training:
            h = symsigmoid(h)
        # Add the fixed noise. It has the per-sample shape (dims 1..3 of ``h``).
        if self.fix_noise is None:
            self.fix_noise = torch.randn(h.shape[1], h.shape[2], h.shape[3])
        # Broadcasting over the leading batch dimension gives the same values
        # as repeating the noise ``h.shape[0]`` times, without materialising
        # (and transferring) a batch-sized copy on every call.
        h += self.fix_noise.to(h.device)
        # Quantize through fp16 and return in fp32.
        h = h.half()
        return h.to(torch.float32)


class ClipControlLDM(ControlLDM):
    """ControlLDM variant that feeds a first-stage-encoded hint to the control model.

    ``p_losses`` returns the raw output of ``self.control_model`` instead of a
    loss value.
    """

    def p_losses(self, x_start, cond, t, noise=None):
        """Run the control model on the perturbed ``x_start`` and the encoded hint.

        Args:
            x_start: Input that is perturbed and passed to the control model.
            cond: Mapping with ``"c_crossattn"`` and ``"c_concat"`` entries,
                each a sequence of tensors concatenated along dim 1.
            t: Timesteps passed to ``q_sample`` and to the control model.
            noise: Optional noise for ``q_sample``; drawn with
                ``torch.randn_like(x_start)`` when omitted.

        Returns:
            The output of ``self.control_model``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        # The ``q_sample`` output is added on top of ``x_start``.
        perturbation = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_noisy = x_start + perturbation

        text_context = torch.cat(cond["c_crossattn"], 1)
        latent_hint = self._encode_hint(cond)
        return self.control_model(
            x=x_noisy,
            hint=latent_hint,
            timesteps=t,
            context=text_context,
        )

    def _encode_hint(self, cond):
        """Concatenate ``cond["c_concat"]`` and encode it with the first-stage model."""
        raw_hint = torch.cat(cond["c_concat"], 1)
        # Affine rescale before encoding; presumably maps [0, 1] to [-1, 1] —
        # TODO confirm the hint's value range with the data pipeline.
        rescaled = 2 * raw_hint - 1
        encoded = self.first_stage_model.encode(rescaled)
        return self.get_first_stage_encoding(encoded).detach()

    def instantiate_first_stage(self, config):
        """Build the first-stage model from ``config`` and freeze it.

        The model is put in eval mode, its ``train`` attribute is replaced by
        ``disabled_train`` and every parameter has ``requires_grad`` turned off.
        """
        first_stage = instantiate_from_config(config)
        self.first_stage_config = config
        self.first_stage_model = first_stage.eval()

        frozen = self.first_stage_model
        frozen.train = disabled_train
        for weight in frozen.parameters():
            weight.requires_grad = False


class ClipControlNet(ControlNet):
    """Control network whose forward pass is a plain ``hint + x`` addition."""

    def forward(self, x, hint, timesteps, context, **kwargs):
        """Return ``hint`` plus ``x`` cast to ``self.dtype``.

        ``timesteps``, ``context`` and ``**kwargs`` are accepted but not used.
        """
        x_cast = x.type(self.dtype)
        return hint + x_cast


class DenoiseClipControlLDM(ClipControlLDM):
    """ClipControlLDM variant that passes ``x_start`` to the control model unperturbed."""

    def p_losses(self, x_start, cond, t, noise=None):
        """Run the control model on ``x_start`` and the encoded hint.

        Despite its name, this returns the output of ``self.control_model``
        rather than a loss.

        Args:
            x_start: Input passed as-is to the control model.
            cond: Mapping with ``"c_crossattn"`` and ``"c_concat"`` entries,
                each a sequence of tensors concatenated along dim 1.
            t: Timesteps passed to the control model.
            noise: Accepted for signature compatibility; not used.

        Returns:
            The output of ``self.control_model``.
        """
        text_context = torch.cat(cond["c_crossattn"], 1)

        raw_hint = torch.cat(cond["c_concat"], 1)
        # Affine rescale before encoding; presumably maps [0, 1] to [-1, 1] —
        # TODO confirm the hint's value range with the data pipeline.
        encoded = self.first_stage_model.encode(2 * raw_hint - 1)
        latent_hint = self.get_first_stage_encoding(encoded).detach()

        return self.control_model(
            x=x_start,
            hint=latent_hint,
            timesteps=t,
            context=text_context,
        )
