在GRL中,要实现的目标是:在前向传导的时候,运算结果不变化,在梯度传导的时候,传递给前面的叶子节点的梯度变为原来的相反方向。举个例子最好说明了:
import torch
from torch.autograd import Function
x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)
z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
s =6* f.sum()
print(s)
s.backward()
print(x)
print(x.grad)
这个程序的运行结果是:
tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])
这个运算过程对于tensor中的每个维度上的运算为:
那么对于x的导数为:
所以当输入x=[1,2,3]时,对应的梯度为:[18,30,42]
因此这个是正常的梯度求导过程,但是如何进行梯度翻转呢?很简单,看下方的代码:
import torch
from torch.autograd import Function
x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)
z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
class GRL(Function):
def forward(self,input):
return input
def backward(self,grad_output):
grad_input = grad_output.neg()
return grad_input
Grl = GRL()
s =6* f.sum()
s = Grl(s)
print(s)
s.backward()
print(x)
print(x.grad)
运行结果为:
tensor(672., grad_fn=<GRL>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])
这个程序相对于上一个程序,只是差在加了一个梯度翻转层:
class GRL(Function):
def forward(self,input):
return input
def backward(self,grad_output):
grad_input = grad_output.neg()
return grad_input
这个部分的forward没有进行任何操作,backward里面做了.neg()操作,相当于进行了梯度的翻转。在torch.autograd 中的FUnction 的backward部分,在不做任何操作的情况下,这里的grad_output的默认值是1.