import torch
from sparseSSDmodel import *
from utils import *
import fnmatch

def load_pretrained_denseSSD_toSparse(sparse_SSD, dense_SSD):
    """
    Copy the parameters of a dense SSD model into a sparse SSD model.

    Entries of the two ``state_dict()``s are matched purely by position,
    using a hard-coded layout:

    * sparse entry 0 is copied unchanged from dense entry 0;
    * sparse entries 1..49 are scanned in order. Every 4-D ``*weight`` entry
      and every ``*bias`` entry consumes the next dense entry (starting at
      dense entry 1); any other entry in that range is left as it is;
    * sparse entry 50 is not modified;
    * sparse entries 51.. are copied unchanged from dense entries 21.. .

    A dense conv kernel of shape ``[out, in, kh, kw]`` is rearranged into
    ``[kh * kw, 1, in, out]`` before being stored, e.g.
    ``[64, 3, 3, 3] -> [9, 1, 3, 64]`` and ``[128, 64, 3, 3] -> [9, 1, 64, 128]``.

    :param sparse_SSD: model receiving the weights; must provide
        ``state_dict()`` and ``load_state_dict()``
    :param dense_SSD: model supplying the weights; must provide ``state_dict()``
    :return: ``sparse_SSD`` after ``load_state_dict()`` has been called on it
    :raises ValueError: if a dense entry paired with a sparse conv weight is
        not 4-D, or its rearranged shape differs from the sparse weight's shape
    :raises IndexError: if the dense model has fewer entries than the
        positional layout requires
    """
    # Positional layout of the two state dicts.
    # NOTE(review): these offsets are tied to the exact parameter ordering of
    # both architectures -- confirm them whenever either model definition changes.
    sparse_conv_stop = 50    # sparse entries [1, 50) are scanned for conv weights/biases
    sparse_tail_start = 51   # sparse entries from here on are copied verbatim ...
    dense_tail_start = 21    # ... from the dense entries starting here
    # NOTE(review): sparse entry 50 falls between the two ranges and is never
    # overwritten -- confirm that this is intended.

    # Current state of the sparse model
    state_dict = sparse_SSD.state_dict()
    param_names = list(state_dict.keys())

    # State of the pretrained dense model
    pretrained_state_dict = dense_SSD.state_dict()
    pretrained_param_names = list(pretrained_state_dict.keys())

    # Transfer conv. parameters from the dense model to the sparse model.
    dense_idx = 1  # position of the next dense entry to consume
    for name in param_names[1:sparse_conv_stop]:
        target = state_dict[name]
        if name.endswith('weight') and target.dim() == 4:
            dense_name = pretrained_param_names[dense_idx]
            kernel = pretrained_state_dict[dense_name]
            if kernel.dim() != 4:
                raise ValueError(
                    f"expected a 4-D conv kernel for dense parameter "
                    f"'{dense_name}' (paired with '{name}'), "
                    f"got shape {tuple(kernel.shape)}"
                )
            out_ch, in_ch, k_h, k_w = kernel.shape
            # [out, in, kh, kw] -> [out, in, 1, kh*kw] -> [kh*kw, 1, in, out].
            # reshape (rather than view) also accepts kernels whose memory
            # layout is not contiguous.
            converted = kernel.reshape(out_ch, in_ch, 1, k_h * k_w).permute(3, 2, 1, 0)
            if converted.shape != target.shape:
                raise ValueError(
                    f"shape mismatch for '{name}': sparse model expects "
                    f"{tuple(target.shape)}, converted dense parameter "
                    f"'{dense_name}' has {tuple(converted.shape)}"
                )
            state_dict[name] = converted
            dense_idx += 1
        elif name.endswith('bias'):
            # Biases are taken over as they are.
            state_dict[name] = pretrained_state_dict[pretrained_param_names[dense_idx]]
            dense_idx += 1

    # The first entry of both models is copied directly.
    state_dict[param_names[0]] = pretrained_state_dict[pretrained_param_names[0]]

    # Remaining entries are copied one-to-one with a fixed positional offset.
    for offset, name in enumerate(param_names[sparse_tail_start:]):
        dense_name = pretrained_param_names[dense_tail_start + offset]
        state_dict[name] = pretrained_state_dict[dense_name]

    sparse_SSD.load_state_dict(state_dict)
    return sparse_SSD


if __name__ == '__main__':
    # Script entry point: build a sparse SSD, fill it from a dense SSD
    # checkpoint, and write the resulting sparse state dict to disk.

    # Locations of the dense input checkpoint and the sparse output file.
    dense_ckpt_path = './VGG16_IXReal_dense_ckp.pth.tar'
    sparse_ckpt_path = './VGG16_IXReal_sparse_ckp.pth.tar'

    # Fresh sparse model with one class per entry of label_map.
    sparse_SSD = SSD300_sparse(n_classes=len(label_map))

    # Read the dense checkpoint onto the CPU and take its 'model' entry.
    # NOTE(review): the entry is used as an object exposing state_dict(),
    # so the checkpoint presumably stores the whole model -- confirm.
    cpu = torch.device('cpu')
    dense_checkpoint = torch.load(dense_ckpt_path, map_location=cpu)
    dense_SSD = dense_checkpoint['model']

    # Move the dense parameters into the sparse layout.
    sparse_SSD = load_pretrained_denseSSD_toSparse(sparse_SSD, dense_SSD)

    # Persist only the sparse model's state dict.
    sparse_weights = sparse_SSD.state_dict()
    torch.save(sparse_weights, sparse_ckpt_path)