import cv2
import numpy as np
import pywt
import pdb


def WaveletTransformation(image, magnitude, mode='haar'):
    """
    对图像进行小波变换，并将LH、HL、HH细节系数保存为图片。

    参数:
    image_path (str): 输入图像的文件路径。
    mode (str): 小波基函数，默认为'haar'。
    level (int): 小波变换的层数，默认为1。
    """
    
    # 执行二维小波变换
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    coeffs = pywt.dwt2(image, mode)
    
    # 从系数中提取LH、HL、HH细节系数
    coeffs_LH, coeffs_HL, coeffs_HH = coeffs[1]
    
    # 将系数转换为图像格式
    # 由于系数值可能超出0-255范围，需要进行归一化处理
    coeffs_LH = (coeffs_LH - coeffs_LH.min()) / (coeffs_LH.max() - coeffs_LH.min())
    coeffs_HL = (coeffs_HL - coeffs_HL.min()) / (coeffs_HL.max() - coeffs_HL.min())
    coeffs_HH = (coeffs_HH - coeffs_HH.min()) / (coeffs_HH.max() - coeffs_HH.min())
    
    # 将归一化后的系数转换为uint8格式，以便保存为图像
    coeffs_LH = (coeffs_LH * 255).astype(np.uint8)
    coeffs_HL = (coeffs_HL * 255).astype(np.uint8)
    coeffs_HH = (coeffs_HH * 255).astype(np.uint8)
    
    #pdb.set_trace()
    
    if(magnitude <= 3):
        coeffs_LH = cv2.resize(coeffs_LH, (image.shape[1], image.shape[0]))
        coeffs_LH = cv2.cvtColor(coeffs_LH, cv2.COLOR_GRAY2BGR)
        return coeffs_LH
    elif(magnitude >= 7):
        coeffs_HL = cv2.resize(coeffs_HL, (image.shape[1], image.shape[0]))
        coeffs_HL = cv2.cvtColor(coeffs_HL, cv2.COLOR_GRAY2BGR)
        return coeffs_HL
    else:
        coeffs_HH = cv2.resize(coeffs_HH, (image.shape[1], image.shape[0]))
        coeffs_HH = cv2.cvtColor(coeffs_HH, cv2.COLOR_GRAY2BGR)
        return coeffs_HH
    


# 使用函数
if __name__ == '__main__':
    image_path = 'check.jpg'  # 更改为你的图片路径
    image = cv2.imread(image_path)
    wavelet = WaveletTransformation(image, magnitude=7)

    cv2.imwrite("Wavelet.jpg", wavelet)