import numpy as np
from math import log10, sqrt
import torch


def PSNR_np(original, compressed, max_pixel=1):
    """Compute the peak signal-to-noise ratio (PSNR) between two images, in dB.

    Both inputs are converted to float64 before the subtraction so that
    unsigned integer images (e.g. uint8) cannot wrap around and silently
    yield a wrong mean squared error.

    Args:
        original: Reference image; anything ``np.asarray`` accepts.
        compressed: Image compared against ``original``; must broadcast
            against it.
        max_pixel: Largest possible pixel value of the images. Defaults to 1
            (images scaled to [0, 1]); pass 255 for 8-bit images.

    Returns:
        The PSNR in decibels as a float, or 100 when the mean squared error
        is exactly zero (identical images), where the formula is undefined.
    """
    difference = (
        np.asarray(original, dtype=np.float64)
        - np.asarray(compressed, dtype=np.float64)
    )
    mse = float(np.mean(difference ** 2))

    # Zero MSE would divide by zero in the formula below; return a fixed cap.
    if mse == 0:
        return 100
    return 20 * log10(max_pixel / sqrt(mse))


def PSNR(original, compressed, max_pixel=1):
    """Compute the peak signal-to-noise ratio (PSNR) between two tensors, in dB.

    Args:
        original: Reference image as a ``torch.Tensor``.
        compressed: Tensor compared against ``original``; must broadcast
            against it.
        max_pixel: Largest possible pixel value of the images. Defaults to 1
            (images scaled to [0, 1]); pass 255 for 8-bit images.

    Returns:
        The PSNR in decibels as a Python float, or 100 when the mean squared
        error is exactly zero (identical images), where the formula is
        undefined.
    """
    # Integer tensors wrap on subtraction (uint8) and are rejected by
    # torch.mean, so promote them to float. Floating tensors are left as-is
    # to keep their dtype and device.
    if not torch.is_floating_point(original):
        original = original.float()
    if not torch.is_floating_point(compressed):
        compressed = compressed.float()

    # Pull the scalar out once so the comparison and math.sqrt use a float.
    mse = torch.mean((original - compressed) ** 2).item()

    # Zero MSE would divide by zero in the formula below; return a fixed cap.
    if mse == 0:
        return 100
    return 20 * log10(max_pixel / sqrt(mse))


def PSRN_accuracy(original_clean_img, predicted_clean_img):
    """Return the mean and standard deviation of per-image PSNR over a batch.

    The batch size is taken from ``original_clean_img.shape[0]``. Each pair
    is scored with ``PSNR`` using only index 0 of the second axis, since both
    inputs are indexed as ``[i, 0, :, :]``.
    NOTE(review): presumably (batch, channel, height, width) tensors --
    confirm against callers.

    Args:
        original_clean_img: Batch of reference images.
        predicted_clean_img: Batch of predicted images, indexed the same way.

    Returns:
        A ``(mean, std)`` tuple of the per-image PSNR values, computed with
        NumPy.
    """
    batch_size = original_clean_img.shape[0]

    # One PSNR value per image, gathered into a float64 array.
    per_image_psnr = np.array(
        [
            PSNR(original_clean_img[idx, 0, :, :], predicted_clean_img[idx, 0, :, :])
            for idx in range(batch_size)
        ],
        dtype=np.float64,
    )

    return per_image_psnr.mean(), per_image_psnr.std()
