import torch
from torch.autograd import Function


class GradientReversalFunction(Function):
	"""
	Gradient Reversal Layer from:
	Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

	Forward pass is the identity function. In the backward pass,
	the upstream gradients are multiplied by -lambda (i.e. gradient is reversed).

	``lambda_`` may be a Python number or a tensor; no gradient is propagated
	to it.
	"""

	@staticmethod
	def forward(ctx, x, lambda_):
		"""Return a copy of ``x`` and remember ``lambda_`` for the backward pass.

		A clone is returned so the output is a new tensor rather than the
		input object itself.
		"""
		ctx.lambda_ = lambda_
		return x.clone()

	@staticmethod
	def backward(ctx, grads):
		"""Return ``-lambda_ * grads`` for ``x`` and ``None`` for ``lambda_``."""
		if not ctx.needs_input_grad[0]:
			return None, None
		lambda_ = ctx.lambda_
		if torch.is_tensor(lambda_):
			# Never propagate a gradient into lambda_ itself.
			lambda_ = lambda_.detach()
		# as_tensor avoids the copy-construct warning that
		# Tensor.new_tensor emits when given a tensor, and matches the
		# gradient's dtype/device.
		lambda_ = torch.as_tensor(lambda_, dtype=grads.dtype, device=grads.device)
		dx = -lambda_ * grads
		return dx, None


class GradientReversal(torch.nn.Module):
	"""Module wrapper around ``GradientReversalFunction``.

	Acts as the identity in the forward pass and multiplies incoming
	gradients by ``-lambda_`` in the backward pass.

	Args:
		lambda_: gradient scaling factor. It is read from ``self.lambda_`` on
			every forward call, so it may be reassigned between steps
			(e.g. to follow a schedule).
	"""

	def __init__(self, lambda_=0.1):
		super().__init__()
		self.lambda_ = lambda_

	def forward(self, x):
		return GradientReversalFunction.apply(x, self.lambda_)

	def extra_repr(self):
		# Shows the current scaling factor when the module is printed.
		return f"lambda_={self.lambda_}"
