import sys

import cv2
import tensorflow as tf
import torch
import numpy as np

from pytorch_implementation.continuous_convolution import ContinuousPool
from tensorflow_implementation.continuous_convolution import continuous_pool

# Parity check: run one image through the PyTorch ContinuousPool and the
# TensorFlow continuous_pool with the same arguments and report whether the
# outputs agree.

# Optional image path on the command line; falls back to the original placeholder.
IMAGE_PATH = sys.argv[1] if len(sys.argv) > 1 else 'random image'
# Positional arguments passed unchanged to both pooling implementations.
POOL_ARGS = (3, 1)

# cv2.imread returns None (it does not raise) when the file is missing or
# unreadable, so fail early with a clear message.
image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
if image is None:
    raise FileNotFoundError('could not read image: {!r}'.format(IMAGE_PATH))

# IMREAD_COLOR yields an (H, W, 3) uint8 array. Add a batch axis (-> NHWC) and
# cast once, so both frameworks receive identical float32 input. Previously
# the torch branch got uint8 while the TF branch got float32.
batch_nhwc = np.expand_dims(image, 0).astype(np.float32)

# The torch branch is fed channels-first (NCHW).
torch_input = torch.from_numpy(
    np.ascontiguousarray(np.transpose(batch_nhwc, (0, 3, 1, 2))))
torch_pool = ContinuousPool(*POOL_ARGS)
with torch.no_grad():  # inference only; also lets .numpy() succeed
    result1 = torch_pool.forward(torch_input).numpy()

tf.reset_default_graph()
tf_input = tf.convert_to_tensor(batch_nhwc, tf.float32)
tf_pool = continuous_pool(tf_input, *POOL_ARGS)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result2 = sess.run(tf_pool)

# Bring the torch result back to NHWC so it lines up with the TF result.
result1_nhwc = np.transpose(result1, (0, 2, 3, 1))
print(result1.shape, result2.shape)

same_shape = result1_nhwc.shape == result2.shape
# Float results from two frameworks can differ by rounding, so compare with a
# tolerance rather than exact equality.
match = same_shape and np.allclose(result1_nhwc, result2, rtol=1e-5, atol=1e-6)
print(match)
if same_shape:
    print('max abs diff:', np.abs(result1_nhwc - result2).max(initial=0.0))
else:
    print('shape mismatch:', result1_nhwc.shape, 'vs', result2.shape)
