DA-CLIP-universal-image-restoration: Code Walkthrough
Creating the model
model = create_model(opt)
device = model.device
create_model ultimately resolves to the ConditionalUNet class, whose main structure is as follows:

class ConditionalUNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, ch_mult=[1, 2, 4, 4], context_dim=512,
                 use_degra_context=True, use_image_context=False, upscale=1):
        ...  # (init_conv, time/prompt embedding layers, etc. omitted)

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        for i in range(self.depth):
            # With an image context, the deepest levels swap linear attention
            # for cross-attention (SpatialTransformer) over that context.
            if use_image_context and context_dim > 0:
                att_down = LinearAttention(dim_in) if i < 3 else SpatialTransformer(
                    dim_in, num_heads_in, dim_head, depth=1, context_dim=context_dim)
                att_up = LinearAttention(dim_out) if i < 3 else SpatialTransformer(
                    dim_out, num_heads_out, dim_head, depth=1, context_dim=context_dim)
            else:
                att_down = LinearAttention(dim_in)   # if i < 2 else Attention(dim_in)
                att_up = LinearAttention(dim_out)    # if i < 2 else Attention(dim_out)

            self.downs.append(nn.ModuleList([
                block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
                block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, att_down)),
                Downsample(dim_in, dim_out) if i != (self.depth - 1) else default_conv(dim_in, dim_out)
            ]))

            self.ups.insert(0, nn.ModuleList([
                block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, att_up)),
                Upsample(dim_out, dim_in) if i != 0 else default_conv(dim_out, dim_in)
            ]))

        mid_dim = nf * ch_mult[-1]
        num_heads_mid = mid_dim // num_head_channels
        self.mid_block1 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)
        if use_image_context and context_dim > 0:
            self.mid_attn = Residual(PreNorm(mid_dim, SpatialTransformer(
                mid_dim, num_heads_mid, dim_head, depth=1, context_dim=context_dim)))
        else:
            self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        self.mid_block2 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)

        self.final_res_block = block_class(dim_in=nf * 2, dim_out=nf, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(nf, out_nc, 3, 1, 1)

    def forward(self, xt, cond, time, text_context=None, image_context=None):
        x = xt - cond                      # residual between current state and the LQ condition
        x = torch.cat([x, cond], dim=1)    # concatenate residual and condition along channels

        H, W = x.shape[2:]
        x = self.check_image_size(x, H, W)

        x = self.init_conv(x)
        x_ = x.clone()                     # kept for the final skip connection

        t = self.time_mlp(time)
        if self.context_dim > 0:
            # The degradation embedding acts as a prompt added onto the time embedding.
            if self.use_degra_context and text_context is not None:
                prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
                prompt_embedding = self.prompt_mlp(prompt_embedding)
                t = t + prompt_embedding

        h = []
        for b1, b2, attn, downsample in self.downs:
            x = b1(x, t)
            h.append(x)
            x = b2(x, t)
            x = attn(x, context=image_context)   # content embedding injected via attention
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x, context=image_context)
        x = self.mid_block2(x, t)

        for b1, b2, attn, upsample in self.ups:
            x = torch.cat([x, h.pop()], dim=1)
            x = b1(x, t)
            x = torch.cat([x, h.pop()], dim=1)
            x = b2(x, t)
            x = attn(x, context=image_context)
            x = upsample(x)

        x = torch.cat([x, x_], dim=1)
        x = self.final_res_block(x, t)
        x = self.final_conv(x)

        x = x[..., :H, :W].contiguous()    # crop back to the original size
        return x
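Before walking through the forward pass, here is a hypothetical invocation sketching the expected tensor shapes. The constructor arguments and the context shapes below are illustrative assumptions, not values taken from the repository:

import torch

# Assumed setup: RGB images, so the concatenated [x, cond] input has 6 channels.
net = ConditionalUNet(in_nc=6, out_nc=3, nf=64,
                      use_degra_context=True, use_image_context=True)

xt   = torch.randn(1, 3, 256, 256)       # current noisy state x_t
cond = torch.randn(1, 3, 256, 256)       # LQ image used as the condition
time = torch.randint(0, 100, (1,))       # diffusion timestep
text_context  = torch.randn(1, 512)      # DA-CLIP degradation embedding (becomes the prompt)
image_context = torch.randn(1, 1, 512)   # DA-CLIP content embedding; cross-attention typically
                                         # expects (batch, tokens, context_dim) -- assumed here

out = net(xt, cond, time, text_context=text_context, image_context=image_context)
print(out.shape)                         # torch.Size([1, 3, 256, 256])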
The LQ image enters the network as cond, concatenated with the current state. Of the two contexts, text_context carries the degradation information and acts as a prompt, added onto the time embedding t, while image_context carries the semantic (content) information and is injected through the attention modules. The features then flow into b1, whose structure is:
class ResBlock(nn.Module):
    def __init__(self, conv, dim_in, dim_out, time_emb_dim=None, act=NonLinearity()):
        super(ResBlock, self).__init__()
        # Projects the time embedding to 2 * dim_out channels (scale and shift).
        self.mlp = nn.Sequential(
            act, nn.Linear(time_emb_dim, dim_out * 2)
        ) if time_emb_dim else None

        self.block1 = Block(conv, dim_in, dim_out, act)
        self.block2 = Block(conv, dim_out, dim_out, act)
        self.res_conv = conv(dim_in, dim_out, 1) if dim_in != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)

        return h + self.res_conv(x)
class Block(nn.Module):
    def __init__(self, conv, dim_in, dim_out, act=NonLinearity()):
        super().__init__()
        self.proj = conv(dim_in, dim_out)
        self.act = act

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x
As the code shows, time_emb is chunked along the channel dimension (dim=1) into two halves, used as scale and shift respectively, which modulate x inside the block:

x = x \cdot (\text{scale} + 1) + \text{shift}

thereby adjusting x conditioned on the timestep (and prompt) embedding.
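A minimal standalone snippet showing just this chunk-and-modulate step, with shapes chosen purely for illustration:

import torch
from einops import rearrange

b, c = 2, 64
x = torch.randn(b, c, 32, 32)
time_emb = torch.randn(b, c * 2)            # output of self.mlp: dim_out * 2 channels

time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale, shift = time_emb.chunk(2, dim=1)     # two halves of size c along the channel dim
x = x * (scale + 1) + shift                 # FiLM-style modulation, broadcast over H and W
print(x.shape)                              # torch.Size([2, 64, 32, 32])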
During training:
timesteps, states = sde.generate_random_states(x0=GT, mu=LQ)
def mu_bar(self, x0, t):
    return self.mu + (x0 - self.mu) * torch.exp(-self.thetas_cumsum[t] * self.dt)
This generates the intermediate quantities: random timesteps and the corresponding states. In mathematical form, mu_bar can be written as:

\mu_{\text{bar}}(x_0, t) = \mu + (x_0 - \mu) \cdot e^{-\theta_{\text{cumsum}}[t] \cdot \Delta t}

This formula describes an exponential decay in which the initial state $x_0$ decays toward the mean $\mu$ over time $t$. The decay rate is controlled by $\theta_{\text{cumsum}}[t]$ and $\Delta t$. As $t$ increases, the factor $e^{-\theta_{\text{cumsum}}[t] \cdot \Delta t}$ shrinks, the influence of $(x_0 - \mu)$ fades, and as $t$ tends to infinity, $\mu_{\text{bar}}(x_0, t)$ approaches $\mu$. Since the call above sets x0=GT and mu=LQ, the forward process pulls the clean image toward its degraded counterpart.
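A toy sketch of this decay, using an assumed linear theta schedule purely to make the behavior visible:

import torch

T = 100
dt = 1.0 / T
thetas = torch.linspace(0.1, 5.0, T)            # assumed schedule, not the repo's
thetas_cumsum = torch.cumsum(thetas, dim=0)

x0, mu = torch.tensor(1.0), torch.tensor(0.0)   # stand-ins for a GT pixel and an LQ pixel
for t in [0, 24, 49, 99]:
    mu_bar = mu + (x0 - mu) * torch.exp(-thetas_cumsum[t] * dt)
    print(t, round(mu_bar.item(), 4))           # decays monotonically from x0 toward mu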
model.feed_data(states, LQ, GT, text_context=degra_context, image_context=image_context)
model.optimize_parameters(current_step, timesteps, sde)
def optimize_parameters(self, step, timesteps, sde=None):
    sde.set_mu(self.condition)

    self.optimizer.zero_grad()

    timesteps = timesteps.to(self.device)

    # Get noise and score
    noise = sde.noise_fn(self.state, timesteps.squeeze(),
                         text_context=self.text_context, image_context=self.image_context)
    score = sde.get_score_from_noise(noise, timesteps)

    # Learning the maximum likelihood objective for state x_{t-1}
    xt_1_expection = sde.reverse_sde_step_mean(self.state, score, timesteps)
    xt_1_optimum = sde.reverse_optimum_step(self.state, self.state_0, timesteps)
    loss = self.weight * self.loss_fn(xt_1_expection, xt_1_optimum)

    loss.backward()
    self.optimizer.step()
    self.ema.update()
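get_score_from_noise is not shown in this post; the reverse_sde_step_mean formula below implies that it converts the predicted noise into a score as $-\text{noise}/\bar{\sigma}_t$. A minimal sketch under that assumption (the buffer name sigma_bars is hypothetical):

def get_score_from_noise(self, noise, t):
    # Assumed form: the score of a Gaussian marginal with std sigma_bar_t
    # is -noise / sigma_bar_t.
    return -noise / self.sigma_bars[t]

The two reverse-step helpers compared by the loss are: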
def reverse_sde_step_mean(self, x, score, t):
    return x - self.sde_reverse_drift(x, score, t)

def reverse_optimum_step(self, xt, x0, t):
    A = torch.exp(-self.thetas[t] * self.dt)
    B = torch.exp(-self.thetas_cumsum[t] * self.dt)
    C = torch.exp(-self.thetas_cumsum[t - 1] * self.dt)

    term1 = A * (1 - C**2) / (1 - B**2)
    term2 = C * (1 - A**2) / (1 - B**2)

    return term1 * (xt - self.mu) + term2 * (x0 - self.mu) + self.mu
The reverse_sde_step_mean function:

\text{reverse\_sde\_step\_mean}\left(x, -\frac{\text{noise}}{\bar{\sigma}_t}, t\right) = x - \left( \theta_t \cdot (\mu - x) + \frac{\sigma_t^2 \cdot \text{noise}}{\bar{\sigma}_t} \right) \cdot \Delta t
- The $\theta_t \cdot (\mu - x)$ term is the drift of the state $x$ toward the mean $\mu$.
- The $\frac{\sigma_t^2 \cdot \text{noise}}{\bar{\sigma}_t}$ term adjusts the state $x$ for the effect of the diffusion coefficient and the predicted noise, where $\bar{\sigma}_t$ is the cumulative noise scale at timestep $t$.
In short, by combining the drift and diffusion contributions, this formula computes the expected state one step back in the reverse SDE process, which is how the model learns to recover the original data from noisy data.
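sde_reverse_drift itself is not shown in the post; here is a sketch consistent with the formula above, assuming score $= -\text{noise}/\bar{\sigma}_t$ and hypothetical buffer names thetas and sigmas:

def sde_reverse_drift(self, x, score, t):
    # theta_t * (mu - x) is the drift toward mu; -sigma_t^2 * score is the
    # noise correction (equal to +sigma_t^2 * noise / sigma_bar_t above).
    return (self.thetas[t] * (self.mu - x) - self.sigmas[t]**2 * score) * self.dt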
The reverse_optimum_step function:

\text{reverse\_optimum\_step}(x_t, x_0, t) = \left( \frac{A \cdot (1 - C^2)}{1 - B^2} \right) \cdot (x_t - \mu) + \left( \frac{C \cdot (1 - A^2)}{1 - B^2} \right) \cdot (x_0 - \mu) + \mu
where:
- $A = e^{-\theta_t \cdot \Delta t}$
- $B = e^{-\theta_{\text{cumsum}}[t] \cdot \Delta t}$
- $C = e^{-\theta_{\text{cumsum}}[t-1] \cdot \Delta t}$
This formula describes how, at a given timestep $t$, the optimal previous state is computed from the current state $x_t$ and the initial state $x_0$. The computation involves three terms:

- Term 1 is the contribution of the current state to the optimal previous state, accounting for the decay from timestep $t$ to $t-1$.
- Term 2 is the contribution of the initial state to the optimal previous state, likewise accounting for the decay.
- The mean term $\mu$ is the process mean, added back at the end so that the optimal previous state is expressed relative to it.
Overall, the formula estimates the optimal previous state by combining information from the current and initial states with the cumulative effect of the schedule up to timestep $t$; in this diffusion model it supervises the reverse generation process, serving as the target that xt_1_expection is regressed onto in the loss above.
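As a small sanity check, here is a toy computation of the two weights term1 and term2, reusing the assumed linear theta schedule from the earlier sketch:

import torch

T = 100
dt = 1.0 / T
thetas = torch.linspace(0.1, 5.0, T)          # assumed schedule, as before
thetas_cumsum = torch.cumsum(thetas, dim=0)

def weights(t):
    A = torch.exp(-thetas[t] * dt)
    B = torch.exp(-thetas_cumsum[t] * dt)
    C = torch.exp(-thetas_cumsum[t - 1] * dt)
    term1 = A * (1 - C**2) / (1 - B**2)       # weight on (x_t - mu)
    term2 = C * (1 - A**2) / (1 - B**2)       # weight on (x_0 - mu)
    return term1.item(), term2.item()

for t in [1, 10, 50, 99]:
    print(t, weights(t))  # the weight shifts from x_0 toward x_t as t grows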