首页 > 其他分享 >grad

grad

时间:2024-01-25 15:01:44浏览次数:18  
标签:unsqueeze self torch x2 x0 x1 grad

class Get_gradient_nopadding_rgb(nn.Module):
    def __init__(self):
        super(Get_gradient_nopadding_rgb, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False).cuda()
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False).cuda()

    def forward(self, x):
        x0 = x[:, 0]
        x1 = x[:, 1]
        x2 = x[:, 2]
        x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=1)
        x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=1)

        x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=1)
        x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=1)

        x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=1)
        x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=1)

        x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
        x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
        x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)

        x = torch.cat([x0, x1, x2], dim=1)
        return x

 

标签:unsqueeze,self,torch,x2,x0,x1,grad
From: https://www.cnblogs.com/yyhappy/p/17987168

相关文章

  • 2024AAAI_SGNet Structure Guided Network via Gradient-Frequency Awareness for Dep
    1.任务描述: 给定输入LR深度图和HRRGB图像,引导DSR目的是在ground-truth深度图监督的条件下,预测HR深度图2.Network本文提出的SGNet主要包括两部分,即梯度校准模块(GCM)和频率感知模块(FAM)。首先将RGB图像和上采样后的LR深度图送入到GCM,利用RGB丰富的梯度信息在梯度域中......
  • 无涯教程-CSS3 - 渐变属性(Gradients)
    渐变显示两种或更多种颜色的组合,如下所示-线性渐变线性渐变用于以线性格式(如从上到下)排列两种或多种颜色。Toptobottom(从上到下)<html><head><style>#grad1{height:100px;background:-webkit-linear-gradient(pink,......
  • CEOI2023D1T3(LOJ4019) Brought Down the Grading Server? (分治+欧拉回路)
    因为我们有\(S=2^k\),所以我们先考虑\(k=1\)即\(S=2\)的时候应该怎么做。发现如果我们对于每一个核心从\(t_1\)向\(t_2\)连一条无向边,如果我们把「不交换\(t_1,t_2\)」看成将这条边定向为\(t_1\tot_2\),否则如果「交换\(t_1,t_2\)」看成将这条边定向为\(t_2\tot_1......
  • Gradle 出现 Could not resolve gradle
    Gradle在进行sync的时候会出现Causedby:org.gradle.internal.resolve.ModuleVersionResolveException:Couldnotresolvegradle:gradle:8.2.查看异常信息发现Gradle无法下载https://services.gradle.org/distributions/gradle-8.2-src.zip,这个链接重定向到https://g......
  • jax框架:jax.grad
    官方地址:https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad这里只给出几个样例代码:设置allow_int参数,实现对整数类型求导:未对整数类型求导:importjaxdeffun(x,y):print(x,y)returnjax.numpy.sum(2*x[0]+y[0]+2*x[1]+......
  • android开发编译出错:Unable to find method ''org.gradle.api.file.RegularFileProper
    Unabletofindmethod''org.gradle.api.file.RegularFilePropertyorg.gradle.api.file.ProjectLayout.fileProperty(org.gradle.api.provider.Provider)'''org.gradle.api.file.RegularFilePropertyorg.gradle.api.file.ProjectLayout.fileProp......
  • gradle仓库配置
    allprojects{repositories{defALIYUN_REPOSITORY_URL='https://maven.aliyun.com/repository/public'defALIYUN_JCENTER_URL='https://maven.aliyun.com/repository/public'defALIYUN_GOOGLE_URL='https://ma......
  • 神经网络优化篇:详解动量梯度下降法(Gradient descent with Momentum)
    动量梯度下降法还有一种算法叫做Momentum,或者叫做动量梯度下降法,运行速度几乎总是快于标准的梯度下降算法,简而言之,基本的想法就是计算梯度的指数加权平均数,并利用该梯度更新的权重。例如,如果要优化成本函数,函数形状如图,红点代表最小值的位置,假设从这里(蓝色点)开始梯度下降法,如果......
  • Android gradle dependency tree change(依赖树变化)监控实现,sdk version 变化一目了然
    @目录前言基本原理执行流程diff报告不同分支merge过来的diff报告同个分支产生的merge报告同个分支提交的diff报告具体实现原理我们需要监控怎样的Dendenpency变化怎样获取dependencyTreeproject.configurations方式./gradlewdependenciesAsciiDependencyReportRe......
  • 用实验来证实CentOS7中yum之update与upgrade之间的异同
    一、实验环境主机IP:10.1.1.21与10.1.1.22操作系统版本:CentOSLinuxrelease7.2.1511(Core)内核版本:3.10.0-327.el7.x86_64二、实验过程1.检查2台机器的初始环境。[root@GeekDevOps~]#cat/etc/redhat-releaseCentOSLinuxrelease7.2.1511(Core)[root@GeekDevOps~]#......