首页 > 其他分享 >cvxpylayer使用(基于Compressive Structured Light for Recovering Inhomogeneous Participating Media论文复现)

cvxpylayer使用(基于Compressive Structured Light for Recovering Inhomogeneous Participating Media论文复现)

时间:2023-07-17 17:46:08浏览次数:32  
标签:Compressive Recovering Media self torch volume shape tch cp

论文中Gini系数的计算
def cal_sparsity(x):
    # print(x.shape)
    n=x.shape[0]
    # x=x.reshape(x.shape.prob)
    x=x.abs()
    x,_=x.sort()
    # print(x)
    Gx=0
    for k in range(n):
        Gx+=x[k]*(n-k+0.5)
    if(x.sum()==0):
        Gx=0
    else:
        Gx*=2/(n*x.sum())
    Gx=1-Gx
    if(math.isinf(Gx)):
        Gx=1
    return Gx
Gini系数代表了数据的离散性(sparsity),论文中优化目标中的常数lambda是数据梯度的Gini系数和数据本身的Gini系数之比。因此这个参数需要提前在我自己的数据上计算。 这里可能出现x.sum()非常小或者就是为0的情况,Gini系数在程序中计算的结果是inf/nan;实际上此时离散度最大,Gini系数为1。

花了一晚上的时间看带绝对值的线性规划怎么转化为标准型,在“目标函数带绝对值号的特殊非线性规划问题”这篇文章中找到了证明:
image
image

然后发现cvxpy是支持绝对值计算的= =,使用cp.pnorm(D @ x, p=1)就行。

定义问题和对应求解器:
image

点击查看代码
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer

n, m = size[2], patternnum   # number of unkown pixels, number of measurements
k = 2*n+1
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)
D = cp.Parameter((k, n))
constraints = [A @ x - b == 0]
objective = cp.Minimize(cp.pnorm(D @ x, p=1))
problem = cp.Problem(objective, constraints)
assert problem.is_dpp()

cvxpylayer = CvxpyLayer(problem, parameters=[A, b, D], variables=[x])
D_tch = torch.zeros(k,n)
A_tch = light_pattern.to(dtype=torch.float32)
lamda=0.8471/0.8588
for i in range(n):
    D_tch[i][i]=1
D_tch[n][0]=lamda
for i in range(n+1,k-1):
    D_tch[i][i-n-1]=lamda
    D_tch[i][i-n]=-1*lamda
D_tch[k-1][n-1]=lamda

直接使用这个求解器:

点击查看代码

loss_avg=0
# Gx=0
# Gdx=0
for i in range(n_valid):
    volume=valid_datasets[i]
    volume=torch.tensor(volume,dtype=torch.float32)
    volume_pred=torch.zeros(volume.shape)
    for x in range(size[0]):
        for y in range(size[1]):
            x_gt=volume[x][y]
            # Gx+=cal_sparsity(x_gt)
            # dx=torch.zeros(size[2]-1)
            # for z in range(size[2]-1):
            #     dx[z]=x_gt[z]-x_gt[z+1]
            # Gdx+=cal_sparsity(dx)
            measurement = A_tch@x_gt
            measurement=measurement+、
			measurement*noise_intensity*torch.randn(measurement.shape)
            b_tch=measurement

            # solve the problem
            solution, = cvxpylayer(A_tch, b_tch, D_tch)
            # print(solution)
            volume_pred[x][y]=solution
    loss=F.mse_loss(volume_pred,volume)
    print(loss)

print("avg:",loss_avg/n_valid)
# print("Gx:",Gx/(n_valid*size[1]*size[0]))
# print("Gdx:",Gdx/(n_valid*size[1]*size[0]))

将求解器作为网路的一部分使用:

点击查看代码
class MyNet(nn.Module):
    def __init__(self,pattern_num,density_shape,cuda):
        super(MyNet, self).__init__()
        self.device=cuda
        self.density_shape=density_shape
		# other parameter omitted
		self.light_pattern=torch.randint(2,[patternnum,size[2]])
        self.noise_intensity=0.1

        n, m = density_shape[2], pattern_num   # number of unkown pixels, number of measurements
        k = 2*n+1
        x = cp.Variable(n)
        A = cp.Parameter((m, n))
        b = cp.Parameter(m)
        D = cp.Parameter((k, n))
        constraints = [A @ x - b == 0]
        objective = cp.Minimize(cp.pnorm(D @ x, p=1))
        problem = cp.Problem(objective, constraints)
        assert problem.is_dpp()

        self.cvxpylayer = CvxpyLayer(problem, parameters=[A, b, D], variables=[x])
        self.D_tch = torch.zeros(k,n).to(device=self.device)
        self.A_tch = self.light_pattern.to(dtype=torch.float32)
        lamda=0.8471/0.8588
        for i in range(n):
            self.D_tch[i][i]=1
        self.D_tch[n][0]=lamda
        for i in range(n+1,k-1):
            self.D_tch[i][i-n-1]=lamda
            self.D_tch[i][i-n]=-1*lamda
        self.D_tch[k-1][n-1]=lamda
    
    def forward(self,volume):
        # print(volume.shape)
		volume=encode(other parameter,volume)
        batch_size=volume.shape[0]
        volume_pred=torch.zeros(volume.shape).to(device=self.device)
        x_gt=volume.reshape(batch_size*self.density_shape[0]*self.density_shape[1],self.density_shape[2])
        measurement = torch.einsum('BN,MN->BM',x_gt,self.A_tch)
        measurement=measurement+measurement*self.noise_intensity*torch.randn(measurement.shape).to(device=self.device)
        b_tch=measurement

        # solve the problem
        solution, = self.cvxpylayer(self.A_tch, b_tch, self.D_tch)
        # print(solution.shape)
        volume_pred=solution.reshape(volume.shape)

        return volume_pred

求解器是可以传入batchsize组数据求解的,但是实际上并不会并行求解,增大batchsize之后平均到每组x上的求解时间并没有变化。。。cvxpylayer内部到底是怎么写的导致这个问题就不知道了。

标签:Compressive,Recovering,Media,self,torch,volume,shape,tch,cp
From: https://www.cnblogs.com/zyx45889/p/17560749.html

相关文章

  • 安装OpenMediaVault服务和Docker应用
    安装SSH服务安装ssh服务sudoaptinstallssh配置ssh:sudovim/etc/ssh/sshd_config在sshd_config文件中找到PasswordAuthentication字段,将其设置为yes:PasswordAuthenticationyes如果需要用root用户远程登录,需要添加一句:PermitRootLoginyes#允许root用户登录启动服务......
  • LayoutRebuilder.ForceRebuildLayoutImmediate的使用和坑点
    LayoutRebuilder.ForceRebuildLayoutImmediate可以强制刷新layout组件,在使用layout和contentsizefitter组件制作如聊天框这种根据文字改变大小之类的UI时很好用。不过LayoutRebuilder.ForceRebuildLayoutImmediate有个坑点,其只有在物体激活时才会生效。......
  • 62.Oracle的实例恢复(instance recovery)和介质恢复(media recovery)
    Oracle数据库中的SCN说明:4种SCN:系统检查点(SystemCheckpoint)SCN数据文件检查点(DatafileCheckpoint)SCN结束SCN(StopSCN)开始SCN(StartSCN)1)systemcheckpointscn --当checkpoing完成后,oracle将systemCheckpointScn号存放在控制文件中,可以通......
  • WinUI(WASDK)使用MediaPipe检查人体姿态关键点
    前言之前有用这个MediaPipe.NET.NET包装库搞了手势识别,丰富了稚晖君的ElectronBot机器人的第三方上位机软件的功能,MediaPipe作为谷歌开源的机器视觉库,功能很丰富了,于是就开始整活了,来体验了一把人体姿态关键点检测。所用框架介绍1.WASDK这个框架是微软最新的应用开发框架,我......
  • Media Encoder 2023-视频编码软件mac/win版
    AdobeMediaEncoder2023是Adobe公司推出的一款专业的媒体编码和转换软件。作为AdobeCreativeCloud套件的一部分,它与其他Adobe创意应用程序(如PremierePro、AfterEffects)无缝集成,提供了一个强大的工具集,用于优化、转换和编码各种媒体文件。→→↓↓载MediaEncoder2......
  • ARM平台移植ZLMediaKit
    ZLMediaKit是一套高性能的流媒体服务框架,目前支持rtmp、rtsp、hls、http-flv等流媒体协议,支持linux、macos、windows三大PC平台和ios、android两大移动端平台。host主机:ubuntu18.04移植平台:rk3568交叉编译链版本:gccversion9.3.0https://github.com/ZLMediaKit/ZLMediaKit1,......
  • windows编译ZLMediaKit(vcpkg)
    windows编译ZLMediaKit转载https://www.jianshu.com/p/f6f1c0b7e32b编译#下载ZLMediaKitgitclonehttps://gitee.com/xia-chu/ZLMediaKit.git#切换到ZLMediaKit目录cdZLMediaKit#更新子模块代码gitsubmoduleupdate--init#vcpkg安装opensslvcpkginstall--trip......
  • ZLMediaKit Windows 编译
    下载ZLToolKit,放到ZLMediaKit-master\3rdpart\ZLToolKit下。https://github.com/ZLMediaKit/ZLToolKit/tree/master下载media-server放到ZLMediaKit-master\3rdpart\media-server下。https://github.com/ireader/media-server下载jsoncpp放到ZLMediaKit-master\3rdpart\jsoncp......
  • MediaPlayer
    1.MediaPlayer使用MediaPlayer媒体框架最重要的组件之一是MediaPlayer类。这个类的对象可以使用最少的设置获取、解码和播放音频和视频。它支持几种不同的媒体来源,如:本地资源内部uri,例如您可能从contentProvider获得的uri外部url(流)有关Android支持的媒体格式列表,请参阅......
  • ZLmediakit集群部署
    1简单理解2如何简单部署测试2.1:我在10.1.1.1的机器上有拉IPC-A摄像头的数据源;这个摄像头通过部署在10.1.1.1上的ZLM去拉,转协议RTSP的地址是rtsp://10.1.1.1:554/rtp/stream_12.2:我本地(127.0.0.1)启动ZLM服务,通过VLC拉stream1:此时肯定是找不到的,因为我本地就没有叫stream_1......