
DA-CLIP-universal-image-restoration Code Walkthrough

Posted: 2024-10-25 17:21:12
Tags: dim CLIP text self DA restoration context time out


  1. Create the model

        model = create_model(opt) 
        device = model.device
    

create_model ultimately builds a ConditionalUNet. The main structure of the class is:

    class ConditionalUNet(nn.Module):
        def __init__(self, in_nc, out_nc, nf, ch_mult=[1, 2, 4, 4], 
                        context_dim=512, use_degra_context=True, use_image_context=False, upscale=1):
        ...
            self.downs = nn.ModuleList([])
            self.ups = nn.ModuleList([])
    
            for i in range(self.depth):
                if use_image_context and context_dim > 0:
                    att_down = LinearAttention(dim_in) if i < 3 else SpatialTransformer(dim_in, num_heads_in, dim_head, depth=1, context_dim=context_dim)
                    att_up = LinearAttention(dim_out) if i < 3 else SpatialTransformer(dim_out, num_heads_out, dim_head, depth=1, context_dim=context_dim)
                else:
                    att_down = LinearAttention(dim_in) # if i < 2 else Attention(dim_in)
                    att_up = LinearAttention(dim_out) # if i < 2 else Attention(dim_out)
    
                self.downs.append(nn.ModuleList([
                    block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
                    block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_in, att_down)),
                    Downsample(dim_in, dim_out) if i != (self.depth-1) else default_conv(dim_in, dim_out)
                ]))
    
                self.ups.insert(0, nn.ModuleList([
                    block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                    block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, att_up)),
                    Upsample(dim_out, dim_in) if i!=0 else default_conv(dim_out, dim_in)
                ]))
    
            mid_dim = nf * ch_mult[-1]
            num_heads_mid = mid_dim // num_head_channels
            self.mid_block1 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)
            if use_image_context and context_dim > 0:
                self.mid_attn = Residual(PreNorm(mid_dim, SpatialTransformer(mid_dim, num_heads_mid, dim_head, depth=1, context_dim=context_dim)))
            else:
                self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
            self.mid_block2 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)
    
            self.final_res_block = block_class(dim_in=nf * 2, dim_out=nf, time_emb_dim=time_dim)
            self.final_conv = nn.Conv2d(nf, out_nc, 3, 1, 1)
    
        def forward(self, xt, cond, time, text_context=None, image_context=None):
            x = xt - cond
            x = torch.cat([x, cond], dim=1)
    
            H, W = x.shape[2:]
            x = self.check_image_size(x, H, W)
    
            x = self.init_conv(x)
            x_ = x.clone()
    
            t = self.time_mlp(time) 
            if self.context_dim > 0:
                if self.use_degra_context and text_context is not None:
                    prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
                    prompt_embedding = self.prompt_mlp(prompt_embedding)
                    t = t + prompt_embedding
    
            h = []
            for b1, b2, attn, downsample in self.downs:
                x = b1(x, t)
                h.append(x)
    
                x = b2(x, t)
                x = attn(x, context=image_context)
                h.append(x)
    
                x = downsample(x)
    
            x = self.mid_block1(x, t)
            x = self.mid_attn(x, context=image_context)
            x = self.mid_block2(x, t)
    
            for b1, b2, attn, upsample in self.ups:
                x = torch.cat([x, h.pop()], dim=1)
                x = b1(x, t)
                
                x = torch.cat([x, h.pop()], dim=1)
                x = b2(x, t)
    
                x = attn(x, context=image_context)
                x = upsample(x)
    
            x = torch.cat([x, x_], dim=1)
    
            x = self.final_res_block(x, t)
            x = self.final_conv(x)
    
            x = x[..., :H, :W].contiguous()
            
            return x
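The channel arithmetic in the up path is easy to misread, so here is a small shape-only sketch that follows the channel count of x through forward(). It assumes that the elided part of __init__ builds the level widths as dims = [nf] + [nf * m for m in ch_mult], that self.depth == len(ch_mult), and that init_conv outputs nf channels; trace_channels is a hypothetical helper, not part of DA-CLIP.

```python
def trace_channels(nf, ch_mult):
    """Follow channel counts (not tensors) through ConditionalUNet.forward().

    Assumption: level widths are dims = [nf, nf*ch_mult[0], ...] and
    init_conv maps the input to nf channels.
    """
    dims = [nf] + [nf * m for m in ch_mult]
    in_out = list(zip(dims[:-1], dims[1:]))   # (dim_in, dim_out) per level

    c = nf      # channels of x (and of x_) after init_conv
    h = []      # skip stack, like the list h in forward()
    for dim_in, dim_out in in_out:
        assert c == dim_in
        h.append(c)   # after b1: dim_in -> dim_in
        h.append(c)   # after b2 and attn
        c = dim_out   # Downsample / default_conv: dim_in -> dim_out

    assert c == nf * ch_mult[-1]   # mid_dim; the mid blocks keep this width

    up_inputs = []                 # input width of each up-path ResBlock
    for dim_in, dim_out in reversed(in_out):   # self.ups was filled with insert(0, ...)
        assert c == dim_out
        c += h.pop()               # cat with skip -> dim_out + dim_in
        up_inputs.append(c)
        c = dim_out                # b1
        c += h.pop()               # cat with the second skip
        up_inputs.append(c)
        c = dim_in                 # b2/attn keep dim_out; Upsample maps dim_out -> dim_in

    assert not h                   # every skip feature was consumed
    return up_inputs, c + nf       # final cat with x_ -> input of final_res_block


ups, final = trace_channels(64, [1, 2, 4, 4])
print(ups)    # [512, 512, 384, 384, 192, 192, 128, 128]
print(final)  # 128
```

The two entries per level match block_class(dim_in=dim_out + dim_in, ...) in self.ups, and the final value equals nf * 2, the input width of final_res_block.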
    

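LQ is passed in as cond together with the current state xt: forward() concatenates xt - cond with cond along the channel dimension. text_context carries the degradation information (training passes text_context=degra_context): it goes through text_mlp and a softmax, is multiplied by the learned self.prompt, passes through prompt_mlp, and the resulting prompt embedding is added to the time embedding t. image_context carries the semantic (content) information and is injected through attention: with use_image_context enabled, the levels with i >= 3 and the middle block use a SpatialTransformer with context_dim, while the other levels use LinearAttention. x and the conditioned t are then fed into b1, whose structure is: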

    import torch.nn as nn
    from einops import rearrange


    def exists(x):
        return x is not None


    def NonLinearity(inplace=False):
        # activation factory; SiLU is assumed here
        return nn.SiLU(inplace)


    class ResBlock(nn.Module):
        # conv is a conv factory such as default_conv(dim_in, dim_out, kernel_size=3)
        def __init__(self, conv, dim_in, dim_out, time_emb_dim=None, act=NonLinearity()):
            super(ResBlock, self).__init__()
            # maps the time (+ degradation prompt) embedding to per-channel scale and shift
            self.mlp = nn.Sequential(
                act, nn.Linear(time_emb_dim, dim_out * 2)
            ) if time_emb_dim else None

            self.block1 = Block(conv, dim_in, dim_out, act)
            self.block2 = Block(conv, dim_out, dim_out, act)
            # 1x1 conv on the residual path when the channel count changes
            self.res_conv = conv(dim_in, dim_out, 1) if dim_in != dim_out else nn.Identity()

        def forward(self, x, time_emb=None):
            scale_shift = None
            if exists(self.mlp) and exists(time_emb):
                time_emb = self.mlp(time_emb)                      # [B, 2 * dim_out]
                time_emb = rearrange(time_emb, 'b c -> b c 1 1')   # broadcast over H, W
                scale_shift = time_emb.chunk(2, dim=1)

            h = self.block1(x, scale_shift=scale_shift)   # only block1 is modulated
            h = self.block2(h)

            return h + self.res_conv(x)


    class Block(nn.Module):
        def __init__(self, conv, dim_in, dim_out, act=NonLinearity()):
            super().__init__()
            self.proj = conv(dim_in, dim_out)
            self.act = act

        def forward(self, x, scale_shift=None):
            x = self.proj(x)

            if exists(scale_shift):
                scale, shift = scale_shift
                x = x * (scale + 1) + shift

            x = self.act(x)
            return x

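time_emb is first mapped by self.mlp to dim_out * 2 channels and then split into two halves along the second (channel) dimension, which serve as scale and shift. block1 applies $x = x\cdot(\text{scale} + 1) + \text{shift}$ to its convolution output before the activation, so the time (and degradation-prompt) embedding modulates $x$ channel by channel.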
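The sketch below follows one sample through this conditioning path with plain Python lists: the degradation prompt is added to the time embedding t (as in ConditionalUNet.forward), and the resulting embedding is split into scale and shift inside a ResBlock. The learned layers text_mlp, prompt_mlp, and self.mlp are replaced by given vectors or the identity, and the helper names are hypothetical.

```python
import math


def softmax(v):
    m = max(v)                      # subtract the max for numerical stability
    e = [math.exp(a - m) for a in v]
    total = sum(e)
    return [a / total for a in e]


def add_degradation_prompt(t_emb, text_logits, prompt):
    """t + softmax(text_mlp(text_context)) * prompt, for one sample.

    text_logits stands in for text_mlp(text_context); prompt_mlp is taken
    as the identity to keep the sketch small (assumption).
    """
    weights = softmax(text_logits)  # softmax over the feature dimension (dim=1)
    return [t + w * p for t, w, p in zip(t_emb, weights, prompt)]


def apply_scale_shift(x, emb):
    """x * (scale + 1) + shift, for one sample.

    x is a list of channels, each a list of pixel values; emb stands in for
    self.mlp(time_emb) and holds 2 * C values.
    """
    c = len(x)
    assert len(emb) == 2 * c
    scale, shift = emb[:c], emb[c:]     # time_emb.chunk(2, dim=1)
    # 'b c -> b c 1 1': one scale/shift per channel, shared by all its pixels
    return [[v * (s + 1) + b for v in ch] for ch, s, b in zip(x, scale, shift)]


t = add_degradation_prompt([0.0, 0.0], [0.0, 0.0], [2.0, 4.0])
print(t)   # [1.0, 2.0]

x = [[1.0, 2.0], [3.0, 4.0]]            # C = 2 channels, 2 pixels each
print(apply_scale_shift(x, [0.0, 1.0, 0.5, -1.0]))   # [[1.5, 2.5], [5.0, 7.0]]
```

Writing the modulation as scale + 1 means a zero embedding leaves the features unchanged apart from the shift, so the conditioning starts out close to an identity map.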

During training:

    timesteps, states = sde.generate_random_states(x0=GT, mu=LQ)

    # method of the SDE class: mean of the state at step t
    def mu_bar(self, x0, t):
        return self.mu + (x0 - self.mu) * torch.exp(-self.thetas_cumsum[t] * self.dt)

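This generates the intermediate quantities used for training: random timesteps and the corresponding states. Mathematically, mu_bar can be written as: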

$$\mu_{\text{bar}}(x_0, t) = \mu + (x_0 - \mu)\, e^{-\theta_{\text{cumsum}}[t]\,\Delta t}$$

This formula describes an exponential decay: the initial state $x_0$ decays toward the mean $\mu$ as $t$ increases, at a rate set by $\theta_{\text{cumsum}}[t]$ and $\Delta t$. As $t$ increases, $e^{-\theta_{\text{cumsum}}[t]\,\Delta t}$ shrinks, so the influence of $(x_0 - \mu)$ fades, and when $\theta_{\text{cumsum}}[t]\,\Delta t \to \infty$, $\mu_{\text{bar}}(x_0, t)$ tends to $\mu$. Since training calls generate_random_states(x0=GT, mu=LQ), the mean of the state moves from the clean image toward the degraded image as $t$ grows.
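A scalar (single-pixel) sketch of the forward process, assuming generate_random_states samples $x_t = \mu_{\text{bar}}(x_0, t) + \bar\sigma_t\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$. The sigma_bar formula (stationary standard deviation max_sigma), the constant theta schedule, and all helper names are assumptions modeled on IR-SDE; they are not shown in the excerpt above.

```python
import math
import random


def thetas_cumsum(thetas):
    """Running sum of theta_t; thetas[0] = 0 so that t = 0 means no decay."""
    out, acc = [], 0.0
    for th in thetas:
        acc += th
        out.append(acc)
    return out


def mu_bar(x0, mu, theta_cumsum_t, dt):
    # mean of x_t: decays exponentially from x0 (GT) toward mu (LQ)
    return mu + (x0 - mu) * math.exp(-theta_cumsum_t * dt)


def sigma_bar(theta_cumsum_t, dt, max_sigma):
    # assumed std of x_t given x0: grows from 0 toward max_sigma
    return math.sqrt(max_sigma ** 2 * (1 - math.exp(-2 * theta_cumsum_t * dt)))


def sample_state(x0, mu, theta_cumsum_t, dt, max_sigma, rng):
    # one pixel of a noisy state: mean plus Gaussian noise with std sigma_bar
    mean = mu_bar(x0, mu, theta_cumsum_t, dt)
    return mean + sigma_bar(theta_cumsum_t, dt, max_sigma) * rng.gauss(0.0, 1.0)


rng = random.Random(0)
cs = thetas_cumsum([0.0] + [0.5] * 100)   # illustrative schedule, T = 100
dt = 0.1
for t in (0, 10, 50, 100):
    mean = mu_bar(1.0, 0.2, cs[t], dt)
    xt = sample_state(1.0, 0.2, cs[t], dt, max_sigma=0.1, rng=rng)
    print(t, round(mean, 4), round(xt, 4))
```

The printed means decrease from x0 = 1.0 toward mu = 0.2, while the sampled states spread around them by an amount that grows with t.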

    model.feed_data(states, LQ, GT, text_context=degra_context, image_context=image_context)
    model.optimize_parameters(current_step, timesteps, sde)

    # the model's optimize_parameters and the SDE methods it calls
    def optimize_parameters(self, step, timesteps, sde=None):
        sde.set_mu(self.condition)
        self.optimizer.zero_grad()
        timesteps = timesteps.to(self.device)

        # Get noise and score
        noise = sde.noise_fn(self.state, timesteps.squeeze(), text_context=self.text_context, image_context=self.image_context)
        score = sde.get_score_from_noise(noise, timesteps)

        # Learning the maximum likelihood objective for state x_{t-1}
        xt_1_expection = sde.reverse_sde_step_mean(self.state, score, timesteps)
        xt_1_optimum = sde.reverse_optimum_step(self.state, self.state_0, timesteps)
        loss = self.weight * self.loss_fn(xt_1_expection, xt_1_optimum)

        loss.backward()
        self.optimizer.step()
        self.ema.update()

    # SDE methods
    def reverse_sde_step_mean(self, x, score, t):
        return x - self.sde_reverse_drift(x, score, t)

    def reverse_optimum_step(self, xt, x0, t):
        A = torch.exp(-self.thetas[t] * self.dt)
        B = torch.exp(-self.thetas_cumsum[t] * self.dt)
        C = torch.exp(-self.thetas_cumsum[t-1] * self.dt)

        term1 = A * (1 - C**2) / (1 - B**2)
        term2 = C * (1 - A**2) / (1 - B**2)

        return term1 * (xt - self.mu) + term2 * (x0 - self.mu) + self.mu

reverse_sde_step_mean:

$$\text{reverse\_sde\_step\_mean}\left(x, -\frac{\text{noise}}{\bar{\sigma}_t}, t\right) = x - \left(\theta_t \cdot (\mu - x) + \frac{\sigma_t^2 \cdot \text{noise}}{\bar{\sigma}_t}\right)\Delta t$$

Here the score argument is $-\text{noise}/\bar{\sigma}_t$, the value returned by get_score_from_noise for the network's noise prediction.

  • $\theta_t \cdot (\mu - x)$ is the mean-reverting drift that pulls the state $x$ toward the mean $\mu$.
  • $\frac{\sigma_t^2 \cdot \text{noise}}{\bar{\sigma}_t}$ equals $-\sigma_t^2 \cdot \text{score}$: the score term scaled by the squared diffusion coefficient $\sigma_t^2$, where $\bar{\sigma}_t$ is the standard deviation of the noise in the state at step $t$.

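Overall, by combining the drift and the score term, this formula gives the expected state $x_{t-1}$ after one reverse-SDE step from $x_t$. It is the quantity the network controls through its noise prediction, which is how the model learns to recover the original data from noisy data.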
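A scalar sketch of this step, with the network's noise prediction converted to a score as in get_score_from_noise. The bodies of get_score_from_noise and sde_reverse_drift are not shown in the excerpt; here they are reconstructed from the formula, so treat them as assumptions, and the flat argument lists (mu, theta_t, sigma_t, dt instead of self attributes) are hypothetical.

```python
def get_score_from_noise(noise, sigma_bar_t):
    # score of N(mu_bar, sigma_bar^2) at the point mu_bar + sigma_bar * noise
    return -noise / sigma_bar_t


def sde_reverse_drift(x, score, mu, theta_t, sigma_t, dt):
    # chosen so that x minus this drift reproduces reverse_sde_step_mean
    return (theta_t * (mu - x) - sigma_t ** 2 * score) * dt


def reverse_sde_step_mean(x, score, mu, theta_t, sigma_t, dt):
    return x - sde_reverse_drift(x, score, mu, theta_t, sigma_t, dt)


x, mu, theta_t, sigma_t, dt = 1.0, 0.0, 0.5, 1.0, 0.1

# zero predicted noise: only the drift acts, and the reverse step moves x away from mu
print(reverse_sde_step_mean(x, 0.0, mu, theta_t, sigma_t, dt))

# here the score term exactly cancels the drift, so x stays where it is
score = get_score_from_noise(0.25, 0.5)
print(reverse_sde_step_mean(x, score, mu, theta_t, sigma_t, dt))
```

With zero predicted noise the step moves x from 1.0 to about 1.05, i.e. it undoes the forward pull toward mu; a nonzero noise prediction shifts the step through the score term.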

reverse_optimum_step:

$$\text{reverse\_optimum\_step}(x_t, x_0, t) = \frac{A\,(1 - C^2)}{1 - B^2}\,(x_t - \mu) + \frac{C\,(1 - A^2)}{1 - B^2}\,(x_0 - \mu) + \mu$$

where:

  • $A = e^{-\theta_t\,\Delta t}$
  • $B = e^{-\theta_{\text{cumsum}}[t]\,\Delta t}$
  • $C = e^{-\theta_{\text{cumsum}}[t-1]\,\Delta t}$

Given step $t$, this formula computes the optimal previous state $x_{t-1}$ from the current state $x_t$ and the initial state $x_0$. It consists of three terms:

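  1. Term 1 is the contribution of the current state $x_t$, weighted by $A(1 - C^2)/(1 - B^2)$; $A$ is the one-step decay factor between steps $t-1$ and $t$.

  2. Term 2 is the contribution of the initial state $x_0$, weighted by $C(1 - A^2)/(1 - B^2)$; $C$ is the cumulative decay up to step $t-1$.

  3. The mean term $\mu$ is added back at the end, because terms 1 and 2 are expressed relative to the process mean.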

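Overall, the formula estimates the most likely previous state by combining the current state and the initial state with the cumulative decay up to step $t$. In optimize_parameters, this optimum (xt_1_optimum) is the target onto which the model's reverse step mean (xt_1_expection) is regressed, which is how the model learns the reverse generation process.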
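A scalar sketch of reverse_optimum_step under an illustrative constant-theta schedule (the schedule values are assumptions; the formula is the one given above). Because thetas_cumsum[t] = thetas_cumsum[t-1] + thetas[t], we have $B = A \cdot C$, which yields two easy sanity checks: at $t = 1$ (where $C = 1$) the optimum is $x_0$ itself, and if $x_t$ sits exactly at its mean $\mu_{\text{bar}}(x_0, t)$, the optimum is $\mu_{\text{bar}}(x_0, t-1)$.

```python
import math
from itertools import accumulate


def reverse_optimum_step(xt, x0, mu, thetas, thetas_cumsum, t, dt):
    """Most likely x_{t-1} given x_t and x_0 (scalar version of the method above)."""
    A = math.exp(-thetas[t] * dt)             # one-step decay between t-1 and t
    B = math.exp(-thetas_cumsum[t] * dt)      # cumulative decay up to t
    C = math.exp(-thetas_cumsum[t - 1] * dt)  # cumulative decay up to t-1
    term1 = A * (1 - C ** 2) / (1 - B ** 2)
    term2 = C * (1 - A ** 2) / (1 - B ** 2)
    return term1 * (xt - mu) + term2 * (x0 - mu) + mu


# illustrative schedule (assumption): theta_0 = 0, then a constant theta
thetas = [0.0] + [0.5] * 100
cs = list(accumulate(thetas))
dt = 0.1
x0, mu = 1.0, 0.2

# at t = 1 the optimum reduces to x0, whatever x_t is
print(reverse_optimum_step(0.7, x0, mu, thetas, cs, 1, dt))

# starting from the mean at step t lands on the mean at step t - 1
t = 40
mean_t = mu + (x0 - mu) * math.exp(-cs[t] * dt)
print(reverse_optimum_step(mean_t, x0, mu, thetas, cs, t, dt))
print(mu + (x0 - mu) * math.exp(-cs[t - 1] * dt))
```

The last two printed values agree up to floating-point rounding, which matches the algebra $A B (1 - C^2) + C (1 - A^2) = C (1 - B^2)$ when $B = A C$.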

From: https://blog.csdn.net/qq_45747799/article/details/143116958
