import numpy as np
import cv2

def rotate(img, ang):
    """Rotate ``img`` by ``ang`` about its center, keeping the canvas size.

    The pivot is the integer midpoint ``(width // 2, height // 2)`` and the
    scale is fixed at 1. ``ang`` is forwarded unchanged to
    ``cv2.getRotationMatrix2D``. The output has the same width and height as
    the input, so content rotated outside the frame is cut off.
    """
    height, width = img.shape[0], img.shape[1]
    pivot = (width // 2, height // 2)
    affine = cv2.getRotationMatrix2D(center=pivot, angle=ang, scale=1)
    # warpAffine takes its destination size as (width, height).
    return cv2.warpAffine(img, affine, (width, height))
def resize(img, size):
    """Return ``img`` resized to ``size`` given as ``(rows, cols)``.

    When the first two dimensions of ``img`` already equal ``size[0]`` and
    ``size[1]`` the input object itself is returned (no copy). Otherwise the
    image is resampled with ``cv2.INTER_AREA``.
    """
    if img.shape[0] == size[0] and img.shape[1] == size[1]:
        return img
    # cv2.resize expects the target as (width, height), i.e. (cols, rows).
    target = (size[-1], size[-2])
    return cv2.resize(img, target, interpolation=cv2.INTER_AREA)


def get_rot_imgs(img, angs, dsize, pad=False, rot_dc=False, dc_size=10):
    """Build a stack of rotated, resized copies of ``img``.

    Args:
        img: Source image; the output stack is allocated with 3 channels and
            ``img.dtype``.
        angs: Sequence of angles. Frame ``i`` is ``rotate(img, -angs[i])``.
        dsize: ``(rows, cols)`` each rotated frame is resized to.
        pad: If true, every output frame is allocated at twice ``dsize`` and
            the resized image occupies its top-left ``dsize`` corner; the rest
            stays zero.
        rot_dc: If false (the default), the central window of every frame is
            overwritten with the central window of the first frame, so that
            region is identical across the stack.
        dc_size: Half-width of that central window, in pixels.

    Returns:
        Array of shape ``(len(angs), H, W, 3)`` where ``(H, W)`` is ``dsize``
        (or twice ``dsize`` when ``pad`` is true). An empty ``angs`` yields an
        empty stack instead of raising.
    """
    n_frames = len(angs)
    if pad:
        out_h, out_w = dsize[0] * 2, dsize[1] * 2
    else:
        out_h, out_w = dsize[0], dsize[1]
    imgs = np.zeros((n_frames, out_h, out_w, 3), dtype=img.dtype)

    for i, ang in enumerate(angs):
        imgs[i, :dsize[0], :dsize[1], :] = resize(rotate(img, -ang), dsize)

    # With no frames there is no first frame to copy the window from.
    if not rot_dc and n_frames > 0:
        print("replacing dc...")
        cy, cx = dsize[0] // 2, dsize[1] // 2
        # Clamp the start at 0: a negative start would wrap around and select
        # a region near the far edge instead of the center.
        rows = slice(max(cy - dc_size, 0), cy + dc_size)
        cols = slice(max(cx - dc_size, 0), cx + dc_size)
        dc_sig = imgs[0, rows, cols].copy()
        imgs[:, rows, cols, :] = dc_sig
    return imgs


def process_img_dtype(img):
    """Convert a non-integer image to 8-bit with its color channels swapped.

    Integer-typed arrays are returned unchanged (same object, no channel
    swap). Any other dtype is clipped to ``[0, 1]``, scaled by 255 and
    truncated to ``uint8``. The color channels are then reordered for OpenCV:

    * 3 channels: the last axis is reversed (RGB <-> BGR).
    * 4 channels: the first three are reversed and the fourth (alpha) stays
      last (RGBA <-> BGRA).
    * anything else (e.g. a 2-D grayscale image): left as is, because the
      last axis is not a color axis and reversing it would mirror the image.
    """
    if np.issubdtype(img.dtype, np.integer):
        return img

    out = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    if out.ndim >= 3:
        channels = out.shape[-1]
        if channels == 3:
            return out[..., ::-1]
        if channels == 4:
            return out[..., [2, 1, 0, 3]]
    return out

def write_img(path, img):
    """Write ``img`` to ``path`` after passing it through ``process_img_dtype``."""
    prepared = process_img_dtype(img)
    cv2.imwrite(path, prepared)


# FourCC code ('mp4v') handed to cv2.VideoWriter for every file opened by Video.write.
video_codec = cv2.VideoWriter_fourcc(*'mp4v')
class Video:
    """Lazily opened video file writer.

    The underlying ``cv2.VideoWriter`` is created on the first ``write`` call,
    using that frame's width and height. Call ``fin`` when done, or use the
    instance as a context manager so the writer is released even if an
    exception interrupts the writing loop::

        with Video("out.mp4", fps=24) as video:
            video.write(frame)
    """

    def __init__(self, path, fps=30):
        """Remember the output ``path`` and frame rate; nothing is opened yet."""
        self.path = path
        self.fps = fps
        # Created by the first write(), once the frame size is known.
        self.out = None

    def __enter__(self):
        """Return the instance itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the writer on exit; exceptions are never suppressed."""
        self.fin()
        return False

    def write(self, frame):
        """Append ``frame`` to the video, opening the writer on first use.

        The frame is passed through ``process_img_dtype`` before writing.
        """
        if self.out is None:
            # VideoWriter takes the frame size as (width, height).
            frame_size = (frame.shape[1], frame.shape[0])
            self.out = cv2.VideoWriter(self.path, video_codec, self.fps, frame_size)
        self.out.write(process_img_dtype(frame))

    def fin(self):
        """Release the writer if one was opened; a no-op otherwise."""
        if self.out is not None:
            self.out.release()