import numpy as np
import torch

# Module-wide default device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def use_cpu():
    """Rebind the module-level ``device`` to the CPU.

    Only code that looks up ``device`` after this call sees the change: tensors that
    were already created stay where they are, and modules that did
    ``from ... import device`` keep their previous binding.
    """
    global device
    device = torch.device('cpu')

# PROXIMAL HELPERS
def soft_threshold(x, thresh):
    """Soft-thresholding (shrinkage): ``sign(x) * max(|x| - thresh, 0)``.

    This is the proximal operator of ``thresh * ||x||_1``.

    Args:
        x: input tensor.
        thresh: threshold; a Python/NumPy scalar or a tensor broadcastable against ``x``.

    Returns:
        The shrunk tensor. For a scalar ``thresh`` it has ``x``'s shape, dtype and device.
    """
    # clamp(min=0) instead of maximum(..., tensor([0], device=device)): no per-call
    # allocation, no dependence on the module-global device, and 0-dim inputs keep
    # their shape instead of being broadcast to shape [1].
    return torch.sign(x) * torch.clamp(torch.abs(x) - thresh, min=0)

# Haar-wavelet helpers used to approximate the closed-form TV proximal solution.
def haar_soft(x, ax, shift, thresh):
    """One-level orthonormal Haar transform along ``ax`` with soft-thresholded details.

    Args:
        x: tensor whose length along ``ax`` must be even.
        ax: axis to transform; any axis of ``x`` (negative values count from the end).
        shift: if True, circularly shift ``x`` by -1 along ``ax`` first, so samples are
            paired as (1, 2), (3, 4), ..., (n-1, 0) instead of (0, 1), (2, 3), ...
        thresh: soft-threshold level applied to the detail coefficients.

    Returns:
        Tensor shaped like ``x``. Along ``ax`` the first half holds the approximation
        coefficients ``(x[2k] + x[2k+1]) / sqrt(2)``. The second half holds the
        thresholded details ``soft_threshold((x[2k+1] - x[2k]) / sqrt(2), thresh)``.

    Raises:
        ValueError: if ``x.shape[ax]`` is odd.
        IndexError: if ``ax`` is not a valid axis of ``x``.
    """
    n = x.shape[ax]
    if n % 2:
        raise ValueError(f"haar_soft needs an even length along axis {ax}, got {n}")
    if shift:
        x = torch.roll(x, -1, dims=ax)
    C = 1 / np.sqrt(2)

    def along(sl):
        # Index tuple that applies ``sl`` on axis ``ax`` and keeps every other axis whole.
        idx = [slice(None)] * x.ndim
        idx[ax] = sl
        return tuple(idx)

    even = x[along(slice(0, None, 2))]
    odd = x[along(slice(1, None, 2))]
    approx = C * (even + odd)
    detail = soft_threshold(C * (odd - even), thresh)
    return torch.cat((approx, detail), dim=ax)
def haar_inv(x, ax, shift):
    """Invert the one-level orthonormal Haar transform laid out by ``haar_soft``.

    Args:
        x: coefficients along ``ax``: approximation half first, detail half second.
            The length along ``ax`` must be even.
        ax: axis to invert; any axis of ``x`` (negative values count from the end).
        shift: if True, circularly shift the result by +1 along ``ax``. This undoes the
            -1 shift ``haar_soft`` applies when called with ``shift=True``.

    Returns:
        Tensor shaped like ``x`` with even samples ``(a - d) / sqrt(2)`` and odd samples
        ``(a + d) / sqrt(2)``, where ``a``/``d`` are the approximation/detail halves.

    Raises:
        ValueError: if ``x.shape[ax]`` is odd.
        IndexError: if ``ax`` is not a valid axis of ``x``.
    """
    n = x.shape[ax]
    if n % 2:
        raise ValueError(f"haar_inv needs an even length along axis {ax}, got {n}")
    m = n // 2
    C = 1 / np.sqrt(2)

    def along(sl):
        # Index tuple that applies ``sl`` on axis ``ax`` and keeps every other axis whole.
        idx = [slice(None)] * x.ndim
        idx[ax] = sl
        return tuple(idx)

    approx = x[along(slice(None, m))]
    detail = x[along(slice(m, None))]
    # Every element is overwritten by the two interleaved assignments below.
    o = torch.empty_like(x)
    o[along(slice(0, None, 2))] = C * (approx - detail)
    o[along(slice(1, None, 2))] = C * (approx + detail)

    if shift:
        o = torch.roll(o, 1, dims=ax)

    return o

# base operator class
class Op:
    """Identity operator that also acts as a zero-valued regularizer.

    Subclasses override ``err`` (the penalty value) and ``prox`` (the proximal map);
    ``forward``/``adjoint`` default to the identity.
    """

    def forward(self, x):
        """Apply the operator (identity in the base class)."""
        return x

    def adjoint(self, x):
        """Apply the adjoint (identity here, since the identity is self-adjoint)."""
        return x

    def err(self, x):
        """Penalty at ``x``; the base class contributes nothing."""
        return 0

    def prox(self, x, lr):
        """Proximal map of the zero function: ``x`` unchanged, ``lr`` ignored."""
        return x

# PROXIMALS
class L1(Op):
    """Weighted L1 (sparsity) penalty ``lmb_sp * sum(|x|)``."""

    def __init__(self, lmb_sp):
        super().__init__()
        self.lmb_sp = lmb_sp  # penalty weight

    def err(self, x):
        """Return ``lmb_sp * ||x||_1``."""
        l1_norm = torch.sum(torch.abs(x))
        return self.lmb_sp * l1_norm

    def prox(self, x, lr):
        """Proximal step with step size ``lr``: soft-threshold ``x`` at ``lr * lmb_sp``."""
        level = lr * self.lmb_sp
        return soft_threshold(x, level)

# Approximate anisotropic TV prior via shift-averaged Haar-wavelet shrinkage, following Kamilov [13].
class TVApprox(Op):
    """Approximate anisotropic total-variation (TV) regularizer.

    ``lmb_tv`` gives one weight per penalized axis; axes ``0 .. len(lmb_tv) - 1`` of the
    input are penalized, so ``err(x) = sum_i lmb_tv[i] * ||diff(x, axis=i)||_1``.

    ``prox`` approximates the TV proximal map. For every penalized axis it applies a
    one-level Haar transform to both the unshifted and the shifted sample pairing and
    soft-thresholds the detail coefficients at ``sqrt(2) * 2 * D * lmb_tv[i] * lr``,
    where ``D`` is the number of penalized axes. It then averages the reconstructions.
    NOTE(review): the Haar helpers pair neighbouring samples, so the penalized axes
    presumably need even lengths; confirm for your inputs.
    """

    def __init__(self, lmb_tv):
        super().__init__()
        self.lmb_tv = np.asarray(lmb_tv)  # one weight per penalized axis
        self.num_tv_dims = len(self.lmb_tv)

    def err(self, x):
        """Weighted sum of absolute finite differences along each penalized axis."""
        return sum(
            (self.lmb_tv[axis] * torch.sum(torch.abs(torch.diff(x, dim=axis)))
             for axis in range(self.num_tv_dims)),
            0,
        )

    def prox(self, x, lr):
        """Average of Haar-shrinkage reconstructions over axes and both pairings."""
        thresh = np.sqrt(2) * 2 * self.num_tv_dims * self.lmb_tv * lr
        total = torch.zeros_like(x)
        for axis in range(self.num_tv_dims):
            for shifted in (False, True):
                coeffs = haar_soft(x, axis, shifted, thresh[axis])
                total += haar_inv(coeffs, axis, shifted)
        return total / (2 * self.num_tv_dims)

# Non-negativity prior
class NonNeg(Op):
    """Non-negativity prior: ``prox`` projects onto ``x >= 0``; ``err`` is inherited (0)."""

    def prox(self, x, lr):
        """Clamp negative entries of ``x`` to zero; ``lr`` is unused.

        ``torch.clamp`` keeps ``x``'s device, dtype and shape (0-dim included). It does
        not compare against a freshly allocated ``tensor([0])`` on the module-global
        device, which broke for tensors on another device.
        """
        return torch.clamp(x, min=0)



# FISTA
def fista(init, target, forward, adjoint=None, iter_cb=None, lr=1e-1, iters=20000, target_min=0,
          target_max=np.inf, reg=None, prox_w=None):
    """Minimize ``0.5 * ||clip(forward(x)) - target||^2 + regularizers`` with FISTA.

    Args:
        init: starting point (tensor).
        target: measurements that ``forward(x)`` is compared against.
        forward: callable mapping ``x`` to model output.
        adjoint: callable applying the adjoint of ``forward`` to a residual. If None, the
            data-term gradient is obtained with autograd instead.
        iter_cb: optional ``iter_cb(i, data_err, total_err, x)`` called after each iteration.
        lr: gradient step size (also passed to every ``prox``).
        iters: maximum number of iterations.
        target_min, target_max: model output is clipped to this range before comparison.
        reg: iterable of regularizers. Each is an object with ``err(x)`` and
            ``prox(z, lr)`` methods (e.g. an ``Op``), or an ``(err, prox)`` tuple.
            None/empty means no regularizer, and the proximal step is the identity.
        prox_w: weights combining the proximal maps as ``sum_i w_i * prox_i(z, lr)``.
            Defaults to uniform weights ``1 / len(reg)``.

    Returns:
        ``(x, fx, err)``. ``x`` is the final momentum-extrapolated iterate. ``fx`` is the
        (detached, clipped) model output of the last evaluation. ``err`` is the data term
        plus regularizer penalties at the last evaluated point, as a float.

    Raises:
        ValueError: if ``prox_w`` is given with a different length than ``reg``.

    Ctrl-C stops the iterations early and returns the current state.
    """
    terms = [_split_reg(r) for r in (reg if reg is not None else [])]
    if prox_w is None:
        weights = [1.0 / len(terms)] * len(terms) if terms else []
    else:
        weights = list(prox_w)
        if len(weights) != len(terms):
            raise ValueError(
                f"prox_w has {len(weights)} weights but reg has {len(terms)} regularizers")

    x = init

    u = torch.clone(x).detach()
    q = 1
    fx = None
    curr_err = None

    for i in range(iters):
        try:
            qo = q
            uo = u

            if adjoint is not None:
                tx = x
            else:
                # Fresh leaf so this iteration's data-term gradient lands in tx.grad.
                tx = x.clone().detach().requires_grad_(True)
            fx = torch.clip(forward(tx), target_min, target_max)
            diff_frame_err = fx - target

            curr_err = 0.5 * torch.linalg.norm(diff_frame_err) ** 2
            data_err = curr_err.item()

            if adjoint is not None:
                z = x - lr * adjoint(diff_frame_err)
            else:
                curr_err.backward()
                z = x - lr * tx.grad.detach()

            if terms:
                # Weighted sum of the proximal maps (an average when the weights sum to 1).
                u = torch.zeros_like(z)
                for (err, prox), weight in zip(terms, weights):
                    u += weight * prox(z, lr)
                    curr_err = curr_err + err(x)
            else:
                # No regularizer: the proximal map is the identity.
                u = z

            q = (1 + np.sqrt(1 + 4 * (q ** 2))) / 2

            # Nesterov/FISTA momentum extrapolation.
            x = u + (qo - 1) / q * (u - uo)

            if iter_cb is not None:
                iter_cb(i, data_err, curr_err.item(), x)

        except KeyboardInterrupt:
            print("Early stopping, graceful exit...")
            break

    if fx is None or curr_err is None:
        # iters == 0, or interrupted before the first evaluation: report the objective at x.
        with torch.no_grad():
            fx = torch.clip(forward(x), target_min, target_max)
            curr_err = 0.5 * torch.linalg.norm(fx - target) ** 2
            for err, _ in terms:
                curr_err = curr_err + err(x)

    return x, fx.detach(), curr_err.item()


def _split_reg(r):
    """Return the ``(err, prox)`` callables of a regularizer given as a tuple or an Op-like object."""
    if isinstance(r, tuple):
        err, prox = r
        return err, prox
    return r.err, r.prox
