import numpy as np
import os
from typing import Optional, Any
import pickle
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import torch
import pdb
import torch.nn.functional as F
import torchvision
from torchvision.models import resnet50
import torchvision.transforms as T

import torch
import torch.nn as nn
from torchvision.ops import roi_align
import math

# Pick the first CUDA device when CUDA is available, otherwise run on the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if use_gpu else torch.device("cpu")
# Let cuDNN benchmark and cache the fastest convolution algorithms
# (most useful when input shapes stay fixed between iterations).
torch.backends.cudnn.benchmark = True

class ObjectsEncoder(nn.Module):
  """Temporal 1-D convolutional encoder followed by a linear classifier.

  Five ``Conv1d -> BatchNorm1d -> ReLU`` stages are applied in sequence. The
  result is pooled over its last axis by summing an adaptive average pool and
  an adaptive max pool (both to length 1). A single linear layer then maps the
  pooled features to ``output_size`` outputs.

  Args:
    input_dim: Number of input channels expected by the first ``Conv1d``.
    output_size: Number of outputs of the final linear layer.
    filter_len: Kernel size of every ``Conv1d``; padding is
      ``filter_len // 2`` (length-preserving for odd ``filter_len``).
    layer_filters: Output channels of the five conv stages, in order.
      Must contain exactly five entries.

  Raises:
    ValueError: if ``layer_filters`` does not contain exactly five entries.
  """

  _NUM_STAGES = 5

  def __init__(self, input_dim, output_size, filter_len,
               layer_filters=(64, 64, 128, 128, 256)):
    # Zero-argument super() avoids hard-coding the class name.
    super().__init__()
    # Copy so the caller's sequence (or a shared default) is never aliased.
    layer_filters = list(layer_filters)
    if len(layer_filters) != self._NUM_STAGES:
      raise ValueError(
          f"layer_filters must have exactly {self._NUM_STAGES} entries, "
          f"got {len(layer_filters)}: {layer_filters}")
    self.input_dim = input_dim
    self.layer_filters = layer_filters

    # Attribute names and Sequential indices match the original layout so
    # existing state_dict checkpoints keep loading; creation order is also
    # unchanged, so seeded parameter initialisation is identical.
    self.temp_conv1 = self._conv_block(input_dim, layer_filters[0], filter_len)
    self.temp_conv2 = self._conv_block(layer_filters[0], layer_filters[1], filter_len)
    self.temp_conv3 = self._conv_block(layer_filters[1], layer_filters[2], filter_len)
    self.temp_conv4 = self._conv_block(layer_filters[2], layer_filters[3], filter_len)
    self.temp_conv5 = self._conv_block(layer_filters[3], layer_filters[4], filter_len)

    self.avg_pool_temporal = nn.AdaptiveAvgPool1d(1)
    self.max_pool_temporal = nn.AdaptiveMaxPool1d(1)

    self.classifier = torch.nn.Sequential(
      torch.nn.Linear(self.layer_filters[-1], output_size)
    )

  @staticmethod
  def _conv_block(in_channels, out_channels, filter_len):
    """Build one ``Conv1d -> BatchNorm1d -> ReLU`` stage."""
    return nn.Sequential(
      nn.Conv1d(in_channels, out_channels, filter_len, padding=filter_len // 2),
      nn.BatchNorm1d(out_channels),
      nn.ReLU(),
    )

  def init_hidden(self, batch_size):
    """Return two random tensors of shape ``(2, batch_size, self.hidden_dim)``.

    Both tensors are moved to the module-level ``device``.

    NOTE(review): nothing in this class sets ``hidden_dim`` and no recurrent
    layer is defined here, so this looks like a leftover from an RNN/LSTM
    variant (the leading 2 is presumably layers * directions — confirm). It
    only works if a caller or subclass assigns ``self.hidden_dim`` first.

    Raises:
      AttributeError: if ``self.hidden_dim`` has not been set.
    """
    hidden_dim = getattr(self, "hidden_dim", None)
    if hidden_dim is None:
      raise AttributeError(
          f"{type(self).__name__}.init_hidden requires self.hidden_dim "
          "to be set before it is called")
    return (torch.randn(2, batch_size, hidden_dim).to(device),
            torch.randn(2, batch_size, hidden_dim).to(device))

  def forward(self, batch):
    """Encode ``batch`` and return classifier outputs.

    Args:
      batch: Input fed to the first ``Conv1d``; presumably shaped
        ``(N, input_dim, L)`` — TODO confirm against callers.

    Returns:
      Tensor of shape ``(N, output_size)`` (for 3-D input).
    """
    features = batch
    for stage in (self.temp_conv1, self.temp_conv2, self.temp_conv3,
                  self.temp_conv4, self.temp_conv5):
      features = stage(features)
    # Both pools reduce the last axis to length 1; summing them mixes
    # average and peak responses per channel.
    pooled = self.avg_pool_temporal(features) + self.max_pool_temporal(features)
    # (N, C, 1) -> (N, C) with C = layer_filters[-1].
    pooled = pooled.view(-1, self.layer_filters[-1])
    return self.classifier(pooled)



