pytorch 实现 GRL Gradient Reversal Layer

news/2024/11/9 21:05:47

在GRL中,要实现的目标是:在前向传导的时候,运算结果不变化,在梯度传导的时候,传递给前面的叶子节点的梯度变为原来的相反方向。举个例子最好说明了:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
s =6* f.sum()

print(s)
s.backward()
print(x)
print(x.grad)

这个程序的运行结果是:

tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])

这个运算过程对于tensor中的每个维度上的运算为:

f(x)=(x^{2}+x)*6

那么对于x的导数为:

\frac{\mathrm{d} f}{\mathrm{d} x} = 12x+6

所以当输入x=[1,2,3]时,对应的梯度为:[18,30,42]

因此这个是正常的梯度求导过程,但是如何进行梯度翻转呢?很简单,看下方的代码:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input


Grl = GRL()

s =6* f.sum()
s = Grl(s)

print(s)
s.backward()
print(x)
print(x.grad)

运行结果为:

tensor(672., grad_fn=<GRL>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])

这个程序相对于上一个程序,只是差在加了一个梯度翻转层:

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input

这个部分的forward没有进行任何操作,backward里面做了.neg()操作,相当于进行了梯度的翻转。在torch.autograd 中的FUnction 的backward部分,在不做任何操作的情况下,这里的grad_output的默认值是1.


http://www.niftyadmin.cn/n/3657692.html

相关文章

多核编程中的负载平衡难题

多核编程中的负载平衡难题作者&#xff1a;周伟明相关文章链接&#xff1a;多核编程中的锁竞争难题 多核编程的几个难题及其应对策略&#xff08;难题一&#xff09; OpenMP并行程序设计&#xff08;二&#xff09; OpenMP并行程序设计&#xff08;一&#xff09; 双核CPU上的快…

对于pytorch中的detach copy 讲解很好的一篇博文

https://blog.csdn.net/guofei_fly/article/details/104486708

测试驱动需求分析--需求文档评审实例

相关文章链接如下&#xff1a;微软过桥问题与测试人员素养 等价类分法 新解 测试用例设计中的NP难题 C/C代码检视实例 90&#xff05;程序员写不出无BUG的二分查找程序&#xff1f; 需求文档评审实例软件的开发文档质量一般只能通过评审来进行保证&#xff0c;如何有效发…

[Invariance Matters: Exemplar Memory for Domain Adaptive Person Re-identification 魔改代码

最近在看这篇文章&#xff0c;以及试着整改代码&#xff0c;按照最初的github设定&#xff0c;跑出来的性能和论文中是一样的&#xff0c;由于论文说了使用camstyle 后的生成图片&#xff0c;我就在想&#xff0c;如果不用这个部分会怎样&#xff0c;我就取消了使用这个数据&am…

使用radix sort 基排序对字符串进行排序

这部分的代码实现的操作是&#xff0c;对一个列表里面的字符串按照字母顺序排序&#xff0c;就像字典里面的单词排序一样&#xff0c;举例子如下&#xff1a; input [jkttsszzo, zie, iukddrjdba, bwjahzwiv, yslzvnjdjg, xkm, aszcnljjl, syniimbq, hqgyd, itvis]output [a…

模块分解原理的探索

模块分解原理的探索在软件高层设计中&#xff0c;如何分解模块是首要考虑的问题。目前业界公认模块划分要按照“高内聚&#xff0c;低耦合”的原则来进行&#xff0c;那么如何划分才能满足“高内聚&#xff0c;低耦合”呢&#xff1f;下面来对模块分解原理方面进行一些探索&…

利用radix sort 基排序对数字进行排序,指定基的基排序实现

基排序的概念就不做解释了&#xff0c;要说的一点是基排序中的这个基是可以任意选择的&#xff0c;只不过网上的大部分radix sort代码都是将10作为了基&#xff0c;我的这个代码是可以任意指定基的&#xff0c;代码如下&#xff1a; import random def numerical_radix_sort(n…

模块分解原理与三权分立

模块分解原理与三权分立相关文章链接&#xff1a;模块分解原理探索前一篇模块分解原理探索的文章中谈到了模块需要按专业领域分解&#xff0c;怎么这篇文章的标题上突然冒出了三权分立&#xff0c;软件怎么和政治制度扯到一起去了&#xff1f;表面看这两个东西好像是风牛马不及…