import torch
import torch.nn as nn
from torchvision import models


class ResNet50VSF(nn.Module):
    """ResNet-50 backbone that maps an input batch to visual-semantic features.

    The backbone is created with ``torchvision.models.resnet50`` and initialised
    from a checkpoint on disk selected by ``args.dataset``. Its classification
    layer (``fc``) is then replaced by a two-layer MLP head whose output size is
    ``args.word_embedding_dim``.

    Args:
        args: Configuration object. The following attributes are read:

            * ``word_embedding_dim`` -- output size of the new head.
            * ``dataset`` -- ``'coco'`` selects the 91-class COCO checkpoint;
              any other value selects the 365-class Places365 checkpoint.
            * ``freeze`` -- if truthy, every backbone parameter that exists at
              that point gets ``requires_grad = False``.
            * ``layer`` -- ``'layer4'`` keeps the full backbone; any other value
              replaces ``layer4`` with ``Identity`` so the head consumes the
              narrower features that reach ``fc`` without it.

    Attributes:
        word_dim: Copy of ``args.word_embedding_dim``.
        vsf_model: The torchvision ResNet-50 with the replaced ``fc`` head.
        feature_dim: Input width of the head (``512 * 4`` or ``256 * 4``).
    """

    def __init__(self, args):
        super(ResNet50VSF, self).__init__()
        self.word_dim = args.word_embedding_dim

        # The class count and the checkpoint always go together, so decide both
        # in one place. Paths are relative to the current working directory.
        if args.dataset == 'coco':
            num_classes = 91
            ckpt_path = './model_outputs/ckpt_pretrain_coco_resnet50.pth.tar'
            strip_module_prefix = False
        else:
            num_classes = 365
            ckpt_path = '../NetDissect-Lite/zoo/resnet50_places365.pth.tar'
            # Keys of this checkpoint contain 'module.' (presumably because it
            # was saved from an nn.DataParallel wrapper -- confirm), which the
            # plain torchvision model does not expect.
            strip_module_prefix = True

        self.vsf_model = models.resnet50(pretrained=False, num_classes=num_classes)

        # map_location='cpu' allows a checkpoint that was saved from a GPU to be
        # restored on a host without CUDA. load_state_dict copies the values
        # into the already-allocated parameters, so the model's own device is
        # not affected by where the checkpoint tensors are materialised.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        state_dict = ckpt['state_dict']
        if strip_module_prefix:
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
        self.vsf_model.load_state_dict(state_dict)

        # Freeze before the head is swapped in: modules created afterwards keep
        # the default requires_grad=True and therefore stay trainable.
        if args.freeze:
            for param in self.vsf_model.parameters():
                param.requires_grad = False

        if args.layer == 'layer4':
            self.feature_dim = 512 * 4
        else:
            # Bypass the last residual stage; the head then sees fewer channels.
            self.vsf_model.layer4 = Identity()
            self.feature_dim = 256 * 4

        # Head: BN -> Dropout -> Linear -> ReLU -> BN -> Dropout -> Linear.
        self.vsf_model.fc = nn.Sequential(
            nn.BatchNorm1d(self.feature_dim),
            nn.Dropout(0.1),
            nn.Linear(in_features=self.feature_dim, out_features=self.feature_dim, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(self.feature_dim),
            nn.Dropout(0.1),
            nn.Linear(in_features=self.feature_dim, out_features=self.word_dim, bias=True),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and the replaced head and return the result."""
        x = self.vsf_model(x)
        return x


class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its input untouched.

    It has no parameters or buffers and relies on the constructor inherited
    from ``nn.Module``. Assigning an instance in place of a sub-module makes
    that sub-module a no-op.
    """

    def forward(self, x):
        """Return ``x`` itself (the same object, no copy)."""
        return x


class ResNet18VSF(nn.Module):
    """ResNet-18 backbone that maps an input batch to visual-semantic features.

    The backbone is created with ``torchvision.models.resnet18`` and initialised
    from a checkpoint on disk selected by ``args.dataset``. Its classification
    layer (``fc``) is then replaced by a two-layer MLP head whose output size is
    ``args.word_embedding_dim``.

    Args:
        args: Configuration object. The following attributes are read:

            * ``word_embedding_dim`` -- output size of the new head.
            * ``dataset`` -- ``'coco'`` selects the 91-class COCO checkpoint;
              any other value selects the 365-class Places365 checkpoint.
            * ``freeze`` -- if truthy, every backbone parameter that exists at
              that point gets ``requires_grad = False``.
            * ``layer`` -- ``'layer4'`` keeps the full backbone; any other value
              replaces ``layer4`` with ``Identity`` so the head consumes the
              narrower features that reach ``fc`` without it.

    Attributes:
        word_dim: Copy of ``args.word_embedding_dim``.
        vsf_model: The torchvision ResNet-18 with the replaced ``fc`` head.
        feature_dim: Input width of the head (``512`` or ``256``).
    """

    def __init__(self, args):
        super(ResNet18VSF, self).__init__()
        self.word_dim = args.word_embedding_dim

        # The class count and the checkpoint always go together, so decide both
        # in one place. Paths are relative to the current working directory.
        if args.dataset == 'coco':
            num_classes = 91
            ckpt_path = './model_outputs/ckpt_pretrain_coco_resnet18.pth.tar'
            strip_module_prefix = False
        else:
            num_classes = 365
            ckpt_path = '../NetDissect-Lite/zoo/resnet18_places365.pth.tar'
            # Keys of this checkpoint contain 'module.' (presumably because it
            # was saved from an nn.DataParallel wrapper -- confirm), which the
            # plain torchvision model does not expect.
            strip_module_prefix = True

        self.vsf_model = models.resnet18(pretrained=False, num_classes=num_classes)

        # map_location='cpu' allows a checkpoint that was saved from a GPU to be
        # restored on a host without CUDA. load_state_dict copies the values
        # into the already-allocated parameters, so the model's own device is
        # not affected by where the checkpoint tensors are materialised.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        state_dict = ckpt['state_dict']
        if strip_module_prefix:
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
        self.vsf_model.load_state_dict(state_dict)

        # Set grad to false to freeze. This happens before the head is swapped
        # in: modules created afterwards keep the default requires_grad=True,
        # so the final fc can still be optimized.
        if args.freeze:
            for param in self.vsf_model.parameters():
                param.requires_grad = False

        if args.layer == 'layer4':
            self.feature_dim = 512
        else:
            # Bypass the last residual stage; the head then sees fewer channels.
            self.vsf_model.layer4 = Identity()
            self.feature_dim = 256

        # Head: BN -> Dropout -> Linear -> ReLU -> BN -> Dropout -> Linear.
        self.vsf_model.fc = nn.Sequential(
            nn.BatchNorm1d(self.feature_dim),
            nn.Dropout(0.1),
            nn.Linear(in_features=self.feature_dim, out_features=self.feature_dim, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(self.feature_dim),
            nn.Dropout(0.1),
            nn.Linear(in_features=self.feature_dim, out_features=self.word_dim, bias=True),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and the replaced head and return the result."""
        x = self.vsf_model(x)
        return x


class ViewFlatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    The result is produced with ``Tensor.view``, so it shares storage with the
    input rather than copying it.
    """

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``."""
        leading = x.size(0)
        flat = x.view(leading, -1)
        return flat

