扩散模型轨迹预测
文章目录
参考论文《Leapfrog Diffusion Model for Stochastic Trajectory Prediction》
CVPR2024
1. 问题定义
目的是得到model
g
θ
(
⋅
)
g_{\theta}(\cdot)
gθ(⋅),参数
θ
\theta
θ来生成分布
P
θ
=
g
θ
(
X
,
X
N
)
\mathcal{P}_{\theta}=g_{\theta}(\mathbf{X},\mathbb{X}_{\mathcal{N}})
Pθ=gθ(X,XN),基于分布
P
θ
\mathcal{P}_\theta
Pθ来画
K
K
K个样本,
Y
^
=
{
Y
1
^
,
Y
2
^
,
.
.
.
,
Y
K
^
}
\hat{\mathcal{\mathbf{Y}}}=\{\hat{\mathbf{Y_1}},\hat{\mathbf{Y_2}},...,\hat{\mathbf{Y_K}}\}
Y^={Y1^,Y2^,...,YK^},这样至少有一个样本是接近真实的未来轨迹。总体问题定义:
θ
∗
=
min
θ
min
Y
^
i
∈
Y
^
D
(
Y
^
i
,
Y
)
,
s.t.
Y
^
∼
P
θ
\theta^{*}=\min _{\theta} \min _{\widehat{\mathbf{Y}}_{i} \in \widehat{\mathcal{Y}}} D\left(\widehat{\mathbf{Y}}_{i}, \mathbf{Y}\right), \quad \text { s.t. } \widehat{\mathcal{Y}} \sim \mathcal{P}_{\theta}
θ∗=minθminY
i∈Y
D(Y
i,Y), s.t. Y
∼Pθ
X
\mathbf{X}
X和
X
N
\mathbb{X}_{\mathcal{N}}
XN分别表示ego车辆的过去轨迹和neighboring车辆,
Y
\mathbf{Y}
Y是ego车辆的未来轨迹。
通过一系列去噪步骤来学习轨迹分布,先执行前向diffusion的加噪到 未来轨迹的ground-truth上,然后,用条件去噪过程从过去轨迹的噪声中来恢复未来轨迹。
Diffusion 过程:
2. 方法论
2.1 前向扩散
Y
0
=
Y
\mathbf{Y}^{0}=\mathbf{Y}
Y0=Y
初始化扩散轨迹,
Y
γ
=
f
diffuse
(
Y
γ
−
1
)
,
γ
=
1
,
⋯
,
Γ
\mathbf{Y}^{\gamma}=f_{\text {diffuse }}\left(\mathbf{Y}^{\gamma-1}\right), \gamma=1, \cdots, \Gamma
Yγ=fdiffuse (Yγ−1),γ=1,⋯,Γ
使用前向
f
d
i
f
f
u
s
e
(
⋅
)
f_{diffuse}(\cdot)
fdiffuse(⋅)向
Y
γ
−
1
\mathbf{Y}^{\gamma-1}
Yγ−1添加连续噪声来获取扩散后的
Y
γ
\mathbf{Y}^{\gamma}
Yγ,其中
Y
γ
\mathbf{Y}^{\gamma}
Yγ是第
γ
\gamma
γ次diffusion步骤
2.2 逆过程
Y
^
k
Γ
∼
i
.
i
.
d
P
(
Y
^
Γ
)
=
N
(
Y
^
Γ
;
0
,
I
)
, sample
K
times
\widehat{\mathbf{Y}}_{k}^{\Gamma} \stackrel{i . i . d}{\sim} \mathcal{P}\left(\widehat{\mathbf{Y}}^{\Gamma}\right)=\mathcal{N}\left(\widehat{\mathbf{Y}}^{\Gamma} ; \mathbf{0}, \mathbf{I}\right) \text {, sample } K \text { times }
Y
kΓ∼i.i.dP(Y
Γ)=N(Y
Γ;0,I), sample K times
从正态分布中抽取
K
K
K个独立同分布的样本初始化去噪轨迹
Y
k
Γ
^
\hat{\mathbf{Y}^\Gamma_k}
YkΓ^
Y
^
k
γ
=
f
denoise
(
Y
^
k
γ
+
1
,
X
,
X
N
)
,
γ
=
Γ
−
1
,
⋯
,
0
\widehat{\mathbf{Y}}_{k}^{\gamma}=f_{\text {denoise }}\left(\widehat{\mathbf{Y}}_{k}^{\gamma+1}, \mathbf{X}, \mathbb{X}_{\mathcal{N}}\right), \gamma=\Gamma-1, \cdots, 0
Y
kγ=fdenoise (Y
kγ+1,X,XN),γ=Γ−1,⋯,0
迭代应用去噪操作
f
d
e
n
o
i
s
e
(
⋅
)
f_{denoise}(\cdot)
fdenoise(⋅)以过去轨迹
X
,
X
N
\mathbf{X},\mathbb{X}_\mathcal{N}
X,XN为条件获取去噪轨迹。
Y
k
γ
^
\hat{\mathbf{Y}_k^\gamma}
Ykγ^是第
γ
\gamma
γ次去噪轨迹第
k
k
k次采样,最终
K
K
K个预测轨迹
Y
=
{
Y
^
1
0
,
Y
^
2
0
,
…
,
Y
^
K
0
}
^
\hat{\mathcal{Y}=\left\{\hat{\mathbf{Y}}_{1}^{0}, \widehat{\mathbf{Y}}_{2}^{0}, \ldots, \widehat{\mathbf{Y}}_{K}^{0}\right\}}
Y={Y^10,Y
20,…,Y
K0}^
Note:
前向扩散处理不会用于推理步骤,在训练期间,
Y
γ
\mathbf{Y}^\gamma
Yγ是第
γ
\gamma
γ步
Y
k
γ
^
\hat{\mathbf{Y}^\gamma_k}
Ykγ^的监督
每个去噪步骤都是扩散步骤的逆过程,每个
Y
γ
\mathbf{Y}^{\gamma}
Yγ和
Y
k
γ
^
\hat{\mathbf{Y}^\gamma_{k}}
Ykγ^共享基础分布。
以上是问题定义建模,传统扩散模型受限于大量去噪的步骤的运算时间限制,但是轨迹预测需要实时推理,如果去噪的步骤很少会导致未来分布的表示能力很弱。
方法
[图片]
2.3 蛙跳扩散模型的步骤
X
\mathbf{X}
X和
X
N
\mathbb{X}_\mathcal{N}
XN分别是ego和neighboring智能体过去的轨迹,
Y
\mathbf{Y}
Y是ego的未来轨迹,
τ
\tau
τ是leapfrog的步数。
Y
0
=
Y
\mathbf{Y}^{0}=\mathbf{Y}
Y0=Y
Y
γ
=
f
diffuse
(
Y
γ
−
1
)
,
γ
=
1
,
⋯
,
Γ
\mathbf{Y}^{\gamma}=f_{\text {diffuse }}\left(\mathbf{Y}^{\gamma-1}\right), \gamma=1, \cdots, \Gamma
Yγ=fdiffuse (Yγ−1),γ=1,⋯,Γ
Y
^
τ
∼
K
P
(
Y
^
τ
)
=
f
L
S
G
(
X
,
X
N
)
\widehat{\mathcal{Y}}^{\tau} \stackrel{K}{\sim} \mathcal{P}\left(\widehat{\mathbf{Y}}^{\tau}\right)=f_{\mathrm{LSG}}\left(\mathbf{X}, \mathbb{X}_{\mathcal{N}}\right)
Y
τ∼KP(Y
τ)=fLSG(X,XN)
和标准扩散模型的步骤最主要的区别在这里,初始化器
f
L
S
G
(
⋅
)
f_{LSG}(\cdot)
fLSG(⋅)直接对第
τ
\tau
τ个去噪分布
P
(
Y
τ
^
)
\mathcal{P}(\hat{\mathbf{Y}^\tau})
P(Yτ^)建模,假设等价于执行
(
Γ
−
τ
)
(\Gamma-\tau)
(Γ−τ)个去噪步骤,从分布
P
(
Y
τ
^
)
\mathcal{P}(\hat{\mathbf{Y}^\tau})
P(Yτ^)中抽样并获取
K
K
K个未来轨迹
Y
^
τ
=
{
Y
^
1
τ
,
Y
^
2
τ
,
…
,
Y
^
K
τ
}
\widehat{\mathcal{Y}}^{\tau}=\left\{\widehat{\mathbf{Y}}_{1}^{\tau}, \widehat{\mathbf{Y}}_{2}^{\tau}, \ldots, \widehat{\mathbf{Y}}_{K}^{\tau}\right\}
Y
τ={Y
1τ,Y
2τ,…,Y
Kτ}
Y
^
k
γ
=
f
denoise
(
Y
^
k
γ
+
1
,
X
,
X
N
)
,
γ
=
τ
−
1
,
⋯
,
0
\widehat{\mathbf{Y}}_{k}^{\gamma}=f_{\text {denoise }}\left(\widehat{\mathbf{Y}}_{k}^{\gamma+1}, \mathbf{X}, \mathbb{X}_{\mathcal{N}}\right), \gamma=\tau-1, \cdots, 0
Y
kγ=fdenoise (Y
kγ+1,X,XN),γ=τ−1,⋯,0
在这一步,只需要对每个轨迹
Y
k
γ
^
\hat{\mathbf{Y}^{\gamma}_k}
Ykγ^应用剩余
τ
\tau
τ个去噪步骤来获取最终的预测
Y
^
=
{
Y
^
1
0
,
Y
^
2
0
,
…
,
Y
^
K
0
}
\widehat{\mathcal{Y}}=\left\{\widehat{\mathbf{Y}}_{1}^{0}, \widehat{\mathbf{Y}}_{2}^{0}, \ldots, \widehat{\mathbf{Y}}_{K}^{0}\right\}
Y
={Y
10,Y
20,…,Y
K0}
Note:
去噪步骤由
Γ
\Gamma
Γ减少到了
τ
\tau
τ,其远远小于
Γ
\Gamma
Γ,模型初始化器对
τ
\tau
τ去噪步骤直接提供了轨迹,加快了推理。和标准扩散模型相比,这里的抽样并非来自独立同分布的结果。
新模型和蛙跳模型使用相同的前向扩散过程,保证了表达能力。
2.4 蛙跳初始化器
通过学习的方式建模第
τ
\tau
τ个去噪分布
P
(
Y
^
)
\mathcal{P}({\hat{\mathbf{Y}}})
P(Y^),将分布拆解为三个部分:均值、全局方差和样本预测部分。过程如下:
μ
θ
=
f
μ
(
X
,
X
N
)
∈
R
T
f
×
2
\mu_{\theta}=f_{\mu}\left(\mathbf{X}, \mathbb{X}_{\mathcal{N}}\right) \in \mathbb{R}^{T_{\mathrm{f}} \times 2}
μθ=fμ(X,XN)∈RTf×2
σ
θ
=
f
σ
(
X
,
X
N
)
∈
R
\sigma_{\theta}=f_{\sigma}\left(\mathbf{X}, \mathbb{X}_{\mathcal{N}}\right) \in \mathbb{R}
σθ=fσ(X,XN)∈R
S
^
θ
=
[
S
^
θ
,
1
,
⋯
,
S
^
θ
,
K
]
=
f
S
^
(
X
,
X
N
,
σ
θ
)
∈
R
T
f
×
2
×
K
\widehat{\mathbb{S}}_{\theta}=\left[\widehat{\mathbf{S}}_{\theta, 1}, \cdots, \widehat{\mathbf{S}}_{\theta, K}\right]=f_{\widehat{\mathbf{S}}}\left(\mathbf{X}, \mathbb{X}_{\mathcal{N}}, \sigma_{\theta}\right) \in \mathbb{R}^{T_{f} \times 2 \times K}
S
θ=[S
θ,1,⋯,S
θ,K]=fS
(X,XN,σθ)∈RTf×2×K
Y
^
k
τ
=
μ
θ
+
σ
θ
⋅
S
^
θ
,
k
∈
R
T
f
×
2
Y
^
k
τ
=
μ
θ
+
σ
θ
⋅
S
^
θ
,
k
∈
R
T
f
×
2
\widehat{\mathbf{Y}}_{k}^{\tau}=\mu_{\theta}+\sigma_{\theta} \cdot \widehat{\mathbf{S}}_{\theta, k} \in \mathbb{R}^{T_{\mathrm{f}} \times 2}\widehat{\mathbf{Y}}_{k}^{\tau}=\mu_{\theta}+\sigma_{\theta} \cdot \widehat{\mathbf{S}}_{\theta, k} \in \mathbb{R}^{T_{\mathrm{f}} \times 2}
Y
kτ=μθ+σθ⋅S
θ,k∈RTf×2Y
kτ=μθ+σθ⋅S
θ,k∈RTf×2
其中,
f
μ
(
⋅
)
,
f
σ
(
⋅
)
,
f
S
^
(
⋅
)
f_{\mu}(\cdot), f_{\sigma}(\cdot), f_{\widehat{\mathbb{S}}}(\cdot)
fμ(⋅),fσ(⋅),fS
(⋅)是三个可训练模块,
μ
θ
,
σ
θ
\mu_{\theta}, \sigma_{\theta}
μθ,σθ是
P
(
Y
^
τ
)
\mathcal{P}\left(\widehat{\mathbf{Y}}^{\tau}\right)
P(Y
τ)的均值和方差,
S
^
θ
,
k
\widehat{\mathbf{S}}_{\theta, k}
S
θ,k是第
k
k
k次采样的正则化位置。
e
μ
θ
social
=
softmax
(
f
q
(
X
)
f
k
(
X
N
)
⊤
d
)
f
v
(
X
N
)
\mathbf{e}_{\mu_{\theta}}^{\text {social }}=\operatorname{softmax}\left(\frac{f_{\mathrm{q}}(\mathbf{X}) f_{\mathrm{k}}\left(\mathbb{X}_{\mathcal{N}}\right)^{\top}}{\sqrt{d}}\right) f_{\mathrm{v}}\left(\mathbb{X}_{\mathcal{N}}\right)
eμθsocial =softmax(d
fq(X)fk(XN)⊤)fv(XN)注意力模块社交编码
e
μ
θ
t
e
m
p
=
f
G
R
U
(
f
conv
1
D
(
X
)
)
\mathbf{e}_{\mu_{\theta}}^{\mathrm{temp}}=f_{\mathrm{GRU}}\left(f_{\operatorname{conv} 1 \mathrm{D}}(\mathbf{X})\right)
eμθtemp=fGRU(fconv1D(X))GRU时序编码
μ
θ
=
f
fusion
(
[
e
μ
θ
social
:
e
μ
θ
temp
]
)
\mu_{\theta}=f_{\text {fusion }}\left(\left[\mathbf{e}_{\mu_{\theta}}^{\text {social }}: \mathbf{e}_{\mu_{\theta}}^{\text {temp }}\right]\right)
μθ=ffusion ([eμθsocial :eμθtemp ])融合均值估计,得到轨迹均值(MLP)
采样预测模块
f
S
^
(
⋅
)
f_{\widehat{\mathbb{S}}}(\cdot)
fS
(⋅)将标准差的估计作为输入,计算过程如下:
e
S
^
θ
σ
=
f
encode
(
σ
θ
)
\mathbf{e}_{\widehat{\mathbb{S}}_{\theta}}^{\sigma}=f_{\text {encode }}\left(\sigma_{\theta}\right)
eS
θσ=fencode (σθ)将标准差的估计经过编码生成高维的embedding
e
S
^
θ
σ
\mathbf{e}_{\widehat{\mathbb{S}}_{\theta}}^{\sigma}
eS
θσ,这样标准差的估计也在样本的预测过程中涉及了。
S
^
θ
=
f
fusion
(
[
e
S
θ
social
:
e
S
^
θ
temp
:
e
S
^
θ
σ
]
)
\widehat{\mathbb{S}}_{\theta}=f_{\text {fusion }}\left(\left[\mathbf{e}_{\mathbb{S}_{\theta}}^{\text {social }}: \mathbf{e}_{\widehat{\mathbb{S}}_{\theta}}^{\text {temp }}: \mathbf{e}_{\widehat{\mathbb{S}}_{\theta}}^{\sigma}\right]\right)
S
θ=ffusion ([eSθsocial :eS
θtemp :eS
θσ])
从蛙跳初始化器获取到K个样本
Y
^
τ
=
{
Y
^
1
τ
,
Y
^
2
τ
,
…
,
Y
^
K
τ
}
\widehat{\mathcal{Y}}^{\tau}=\left\{\widehat{\mathbf{Y}}_{1}^{\tau}, \widehat{\mathbf{Y}}_{2}^{\tau}, \ldots, \widehat{\mathbf{Y}}_{K}^{\tau}\right\}
Y
τ={Y
1τ,Y
2τ,…,Y
Kτ},然后执行
τ
\tau
τ个去噪步骤来迭代精炼预测轨迹。
问题:剩余的
τ
\tau
τ步指的是哪
τ
\tau
τ步?采样K次是为了什么?
2.5 去噪模块
去噪模块
f
d
e
n
o
i
s
e
(
⋅
)
f_{denoise}(\cdot)
fdenoise(⋅)从过去的轨迹
(
X
,
X
N
)
(\mathbf{X},\mathbb{X}_\mathcal{N})
(X,XN)的条件下对轨迹,有两个训练模块,一个是基于Transformer的上下文编码模块学习社交-时序embedding和一个噪声估计模块
f
ϵ
(
⋅
)
f_{\epsilon}(\cdot)
fϵ(⋅)用以估计需要减少的噪声。第
γ
\gamma
γ步去噪的流程如下:
C
=
f
context
(
X
,
X
N
)
\mathbf{C}=f_{\text {context }}\left(\mathbf{X}, \mathbb{X}_{\mathcal{N}}\right)
C=fcontext (X,XN)使用上下文编码模块从过去的轨迹中获取上下文条件
C
\mathbf{C}
C,
f
c
o
n
t
e
x
t
(
X
,
X
N
)
f_{context}(\mathbf{X},\mathbb{X}_\mathcal{N})
fcontext(X,XN)和
f
μ
(
⋅
)
f_{\mu}(\cdot)
fμ(⋅)是相同的结构
ϵ
θ
γ
=
f
ϵ
(
Y
^
k
γ
+
1
,
C
,
γ
+
1
)
\boldsymbol{\epsilon}_{\theta}^{\gamma}=f_{\boldsymbol{\epsilon}}\left(\widehat{\mathbf{Y}}_{k}^{\gamma+1}, \mathbf{C}, \gamma+1\right)
ϵθγ=fϵ(Y
kγ+1,C,γ+1)通过上下文
C
\mathbf{C}
C的多层感知机实现的噪声估计
f
ϵ
(
⋅
)
f_{\epsilon}(\cdot)
fϵ(⋅)来估计带噪声轨迹
Y
^
k
γ
+
1
\widehat{\mathbf{Y}}_{k}^{\gamma+1}
Y
kγ+1的噪声
ϵ
θ
γ
\boldsymbol{\epsilon}_{\theta}^{\gamma}
ϵθγ
Y
^
k
γ
=
1
α
γ
(
Y
^
k
γ
+
1
−
1
−
α
γ
1
−
α
ˉ
γ
ϵ
θ
γ
)
+
1
−
α
γ
z
\widehat{\mathbf{Y}}_{k}^{\gamma}=\frac{1}{\sqrt{\alpha_{\gamma}}}\left(\widehat{\mathbf{Y}}_{k}^{\gamma+1}-\frac{1-\alpha_{\gamma}}{\sqrt{1-\bar{\alpha}_{\gamma}}} \boldsymbol{\epsilon}_{\theta}^{\gamma}\right)+\sqrt{1-\alpha_{\gamma} \mathbf{z}}
Y
kγ=αγ
1(Y
kγ+1−1−αˉγ
1−αγϵθγ)+1−αγz
标准去噪步骤
其中
α
ˉ
γ
=
∏
i
=
1
γ
α
i
\bar{\alpha}_{\gamma}=\prod_{i=1}^{\gamma} \alpha_{i}
αˉγ=∏i=1γαi和
α
γ
\alpha_\gamma
αγ是扩散过程的参数,
z
∼
N
(
z
;
0
,
I
)
\mathbf{z} \sim \mathcal{N}(\mathbf{z} ; \mathbf{0}, \mathbf{I})
z∼N(z;0,I)是噪声
3. 实践
3.1训练
分为两个阶段,第一阶段训练去噪模块、第二阶段聚焦于蛙跳初始化器。
(蛙跳初始化器的在给定分布
P
(
Y
^
τ
)
\mathcal{P}(\widehat{\mathbf{Y}}^\tau)
P(Y
τ)的情况下,蛙跳初始化器的训练更加稳定)
第一阶段使用扩散模型的标准训练模式利用噪声估计的loss训练:
L
N
E
=
∥
ϵ
−
f
ϵ
(
Y
γ
+
1
,
f
context
(
X
,
X
N
)
,
γ
+
1
)
∥
2
\mathcal{L}_{\mathrm{NE}}=\left\|\boldsymbol{\epsilon}-f_{\boldsymbol{\epsilon}}\left(\mathbf{Y}^{\gamma+1}, f_{\text {context }}\left(\mathbf{X}, \mathbb{X}_{\mathcal{N}}\right), \gamma+1\right)\right\|_{2}
LNE=
ϵ−fϵ(Yγ+1,fcontext (X,XN),γ+1)
2
其中,
γ
∼
U
{
1
,
2
,
⋯
,
Γ
}
,
ϵ
∼
N
(
ϵ
;
0
,
I
)
\gamma \sim U\{1,2, \cdots, \Gamma\},\boldsymbol{\epsilon} \sim \mathcal{N}(\boldsymbol{\epsilon} ; \mathbf{0}, \mathbf{I})
γ∼U{1,2,⋯,Γ},ϵ∼N(ϵ;0,I),扩散的轨迹
Y
γ
+
1
=
α
ˉ
γ
Y
0
+
1
−
α
ˉ
γ
ϵ
\mathbf{Y}^{\gamma+1}=\sqrt{\bar{\alpha}_{\gamma}} \mathbf{Y}^{0}+\sqrt{1-\bar{\alpha}_{\gamma}} \boldsymbol{\epsilon}
Yγ+1=αˉγ
Y0+1−αˉγ
ϵ
反向传播loss和训练的参数在上下文编码模块和噪声估计模块。
第二阶段使用可训练的蛙跳初始化器优化蛙跳扩散模型并且冻结去噪模块。对于每一个样本,loss为
L
=
L
distance
+
L
uncertainty
=
w
⋅
min
k
∥
Y
−
Y
^
k
∥
2
+
(
∑
k
∥
Y
−
Y
^
k
∥
2
σ
θ
2
K
+
log
σ
θ
2
)
\begin{aligned} \mathcal{L} & =\mathcal{L}_{\text {distance }}+\mathcal{L}_{\text {uncertainty }} \\ & =w \cdot \min _{k}\left\|\mathbf{Y}-\widehat{\mathbf{Y}}_{k}\right\|_{2}+\left(\frac{\sum_{k}\left\|\mathbf{Y}-\widehat{\mathbf{Y}}_{k}\right\|_{2}}{\sigma_{\theta}^{2} K}+\log \sigma_{\theta}^{2}\right) \end{aligned}
L=Ldistance +Luncertainty =w⋅kmin
Y−Y
k
2+
σθ2K∑k
Y−Y
k
2+logσθ2
其中
w
∈
R
w\in\mathbb{R}
w∈R是超参数权重,第一项约束了K个预测的最小距离。直观上说,如果蛙跳初始化器生成分布
P
(
Y
^
τ
)
\mathcal{P}(\widehat{\mathbf{Y}}^\tau)
P(Y
τ)的高质量估计,K个预测中之一一定与groud-truth非常接近。
第二项通过不确定性损失对重新参数化的方差估计
σ
θ
\sigma_\theta
σθ进行归一化,平衡预测多样性和平均精度。方差估计控制预测的分散性,弥合场景复杂性和预测多样性。
∑
k
∥
Y
−
Y
^
k
∥
2
σ
θ
2
K
\frac{\sum_{k}\left\|\mathbf{Y}-\widehat{\mathbf{Y}}_{k}\right\|_{2}}{\sigma_{\theta}^{2} K}
σθ2K∑k∥Y−Y
k∥2将
σ
θ
\sigma_\theta
σθ的值与场景的复杂度呈正比关系。
log
σ
θ
2
\log \sigma_{\theta}^{2}
logσθ2使用正则化器为所有的预测生成高方差。
附:作者表达了技术手段,他们并不在第二阶段对蛙跳初始化器进行估计时使用显式监督的手段的原因如下:
显式监督的实现:在初始化器的估计过程中,分布
P
(
Y
^
Γ
)
\mathcal{P}(\widehat{\mathbf{Y}}^\Gamma)
P(Y
Γ)可以从正态分布中去噪,为了完成显式监督,从正态分布下的
P
(
Y
^
Γ
)
\mathcal{P}(\widehat{\mathbf{Y}}^\Gamma)
P(Y
Γ)中抽样M个(M>>K),然后通过
f
d
e
n
o
i
s
e
f_{denoise}
fdenoise去噪直到获得去噪轨迹
Y
^
τ
\widehat{\mathbf{Y}}^{\tau}
Y
τ,然后使用这M个样本计算统计量作为均值估计
f
μ
(
⋅
)
f_\mu(\cdot)
fμ(⋅)和方差估计
f
σ
(
⋅
)
f_\sigma(\cdot)
fσ(⋅)的显式监督。
然而,由于
τ
<
<
Γ
\tau<<\Gamma
τ<<Γ,那么对
M
>
>
K
M>>K
M>>K个样本运行
(
Γ
−
τ
)
≈
Γ
(\Gamma-\tau) \approx \Gamma
(Γ−τ)≈Γ步去噪获得统计数据的过程会导致训练时间和存储消耗变得无法接受。(NBV数据集上一个epoch要6天)
3.2 推理阶段
在推理过程中,蛙跳扩散模型只需要
τ
\tau
τ步,而不是
Γ
\Gamma
Γ步去噪,从而加快了推理速度。具体来说,我们首先生成 K 个相关样本,以使用经过训练的蛙跳初始化器对分布
P
(
Y
~
τ
)
\mathcal{P}\left(\widetilde{\mathbf{Y}}^{\tau}\right)
P(Y
τ)进行建模。然后,这些样本将被输入去噪过程并迭代微调以产生最终的预测;参见算法 :
3.3 源码
实验复现:
conda create -n led python=3.7
conda activate led
注意spconv 1.x版本已经废弃,找不到了
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
git clone -b v1.2.1 --recursive https://github.com/traveller59/spconv.git
cd spconv
python setup.py bdist_wheel
cd dist
pip install spconv-1.2.1-cp37-cp37m-linux_x86_64.whl
https://drive.google.com/drive/folders/1Uy8-WvlCp7n3zJKiEX0uONlEcx2u3Nnx
数据下载地址
训练
python main_led_nba.py --cfg led_augment --gpu 0 --train 1 --info try1
3.4 问题记录
问题:
查看日志文件
看到了两处报错,一处是LIBC_PTHREAD找不到,一处是nvcc报错。
经过排查,需要完整克隆指令(https://github.com/traveller59/spconv/issues/264)
0.11.4这个不兼容torch 1.8.0
标签:轨迹,mathbf,widehat,theta,mathcal,蛙跳,扩散,gamma,left From: https://blog.csdn.net/qq_38853759/article/details/140657281