import torch
import math
from torch.optim.lr_scheduler import LambdaLR

def set_learning_rate(optimizer, train_iter, args):
    """Set the learning rate of every parameter group for iteration ``train_iter``.

    The schedule has two phases:

    * Linear warmup: while ``args.warmup`` is truthy and
      ``train_iter <= args.warmup_iter``, the rate is
      ``args.lr * train_iter / args.warmup_iter``.
    * Step decay: otherwise the rate is ``args.lr`` divided by 10 once for
      every milestone in ``args.lr_drop_iter`` that has been reached
      (``train_iter >= milestone``).

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts); the
            ``'lr'`` entry of each group is overwritten.
        train_iter: Current training iteration.
        args: Namespace providing ``lr``, ``warmup``, ``warmup_iter`` and
            ``lr_drop_iter`` (increasing iteration milestones; any length,
            including empty, is accepted).

    Returns:
        The learning rate that was written into the parameter groups.
    """
    # A non-positive warmup length means "no warmup"; this also avoids a
    # division by zero when warmup_iter == 0.
    in_warmup = (
        args.warmup
        and args.warmup_iter > 0
        and train_iter <= args.warmup_iter
    )

    if in_warmup:
        lr = args.lr * (float(train_iter) / float(args.warmup_iter))
    else:
        milestones = tuple(args.lr_drop_iter)
        # Index of the first milestone not yet reached == number of drops so far.
        drops = next(
            (idx for idx, milestone in enumerate(milestones)
             if train_iter < milestone),
            len(milestones),
        )
        # 1.0 / 10**k is the correctly rounded double for 0.1, 0.01, 0.001, ...
        # so the result matches multiplying by those literals.
        lr = args.lr * (1.0 / 10 ** drops)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr


from bisect import bisect_right
import warnings
from torch.optim.lr_scheduler import _LRScheduler

class WarmupMultiStepLR(_LRScheduler):
    """Multi-step decay with a linear warmup at the start.

    For each parameter group the rate at step ``last_epoch`` is::

        base_lr * warmup * gamma ** (number of milestones <= last_epoch)

    where ``warmup`` is 1 once ``last_epoch >= warmup_iters`` and otherwise
    interpolates linearly from ``warmup_factor`` (at step 0) towards 1.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (iterable of int): Step indices at which the rate is
            multiplied by ``gamma``. Must be sorted in non-decreasing order.
        gamma (float): Multiplicative decay factor. Default: 0.1.
        warmup_factor (float): Multiplier applied at step 0 of the warmup.
            Default: 1/3.
        warmup_iters (int): Number of warmup steps. Default: 500.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: If ``milestones`` is not sorted.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
                 warmup_iters=500, last_epoch=-1):
        # Materialise once so one-shot iterables are not exhausted by the
        # sortedness check before being stored.
        milestones = list(milestones)
        if milestones != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        # The base constructor may call get_lr(), so every attribute it reads
        # is assigned before this call.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each parameter group for ``last_epoch``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            # alpha runs from 0 to 1 over the warmup window.
            alpha = float(self.last_epoch) / self.warmup_iters
            warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # bisect_right counts the milestones that are <= last_epoch.
        decay_steps = bisect_right(self.milestones, self.last_epoch)
        return [
            base_lr * warmup_factor * self.gamma ** decay_steps
            for base_lr in self.base_lrs
        ]


class WarmupStepLR(_LRScheduler):
    """Step decay every ``step_size`` steps, preceded by a linear warmup.

    The rate of each parameter group at step ``last_epoch`` is::

        base_lr * warmup * gamma ** (last_epoch // step_size)

    where ``warmup`` equals 1 once ``last_epoch >= warmup_iters`` and before
    that interpolates linearly from ``warmup_factor`` (at step 0) towards 1.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        warmup_factor (float): Multiplier applied at step 0 of the warmup.
            Default: 1/3.
        warmup_iters (int): Number of warmup steps. Default: 500.

    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, warmup_factor=1.0 / 3,
                 warmup_iters=500):
        # Warmup settings.
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor
        # Step-decay settings.
        self.gamma = gamma
        self.step_size = step_size
        # Attributes are assigned first because the base constructor may
        # already invoke get_lr().
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each parameter group for ``last_epoch``."""
        step = self.last_epoch

        if step < self.warmup_iters:
            # Fraction of the warmup window already completed.
            alpha = float(step) / self.warmup_iters
            ramp = self.warmup_factor * (1 - alpha) + alpha
        else:
            ramp = 1

        completed_periods = step // self.step_size
        rates = []
        for base_lr in self.base_lrs:
            rates.append(base_lr * ramp * self.gamma ** completed_periods)
        return rates


def WarmupCosineLR(optimizer,
                   num_warmup_steps,
                   num_training_steps,
                   num_cycles=7./16.,
                   last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier applied to each base learning rate at ``current_step`` is:

    * ``current_step / max(1, num_warmup_steps)`` while
      ``current_step < num_warmup_steps``;
    * ``max(0, cos(pi * num_cycles * progress))`` afterwards, where
      ``progress`` is the fraction of the post-warmup steps completed.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_warmup_steps (int): Number of linear warmup steps.
        num_training_steps (int): Total number of steps, warmup included.
        num_cycles (float): Scales the cosine argument. Default: 7/16.
        last_epoch (int): The index of last epoch. Default: -1.

    Returns:
        The configured ``LambdaLR`` scheduler.
    """
    def _lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            # Cosine phase: progress is measured from the end of warmup.
            steps_after_warmup = float(current_step - num_warmup_steps)
            decay_span = float(max(1, num_training_steps - num_warmup_steps))
            progress = steps_after_warmup / decay_span
            cosine = math.cos(math.pi * num_cycles * progress)
            # Clamp so the multiplier never goes negative.
            return max(0., cosine)

        # Warmup phase: ramp linearly up from 0.
        return float(current_step) / float(max(1, num_warmup_steps))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)
