inner_edges (ctx_edges)会包含全局节点的部分
第一部分即使E3-E3和Antibody-Antibody
还有三个是我之前忽略了的
在这里,每个蛋白质的 3 个全局节点分别代表抗原、重链和轻链的全局信息。让我们逐步解析 global_normal
、global_global
和 seq_edges
的含义,以及它们的构建逻辑。
1. 每个蛋白质的 3 个全局节点
每个蛋白质会有 3 个全局节点,用于分别代表:
- 抗原的全局节点(标识为
BOA
,如'&'
) - 重链的全局节点(标识为
BOH
,如'+'
) - 轻链的全局节点(标识为
BOL
,如'-'
)
这些全局节点的设计目的是在模型中为每个链提供一个集中信息的节点,以帮助捕捉整个抗原、重链和轻链的全局特征。
2. global_normal
和 global_global
边的含义
在 self._construct_global_edges()
方法中:
-
global_normal
边:连接 全局节点 和 普通节点 的边。这些边确保每个链的全局节点与其链中的其他节点相连,以便在图中传递全局信息。- 例如,抗原全局节点
BOA
会与抗原链上的所有其他节点连接;重链和轻链的全局节点也分别与重链和轻链的普通节点连接。
- 例如,抗原全局节点
-
global_global
边:连接 全局节点之间 的边,这些边允许不同链的全局节点之间通信。- 比如,抗原的全局节点可以连接到重链和轻链的全局节点,以共享跨链的全局信息。
3. seq_edges
的含义
seq_edges
表示序列中的相邻节点之间的边,即在一维序列上顺序相邻的节点之间建立连接。其目的是确保每个残基的相邻节点能够直接连接起来,从而保持蛋白质链的结构顺序。例如:
- 如果节点
i
和i+1
在一维序列上是相邻的,那么它们之间会建立一个seq_edge
,帮助模型在图上保持链的顺序结构。
总结
global_normal
边:将链的全局节点与该链中的普通节点连接。global_global
边:将不同链的全局节点相互连接,促进跨链的全局信息传递。seq_edges
:在序列上相邻的节点之间建立连接,确保模型理解蛋白质链的线性顺序。
这些边的设计帮助模型在全局和局部层面都能传递信息,确保不同链及其全局节点间的信息交换。
local_edges = local_ctx_edges + local_ inter_edges
在这个模型中,paratope
的设置直接影响 local_ctx_edges
和 local_inter_edges
的构建,限制了边构建的节点范围和连接方式。为了展示这个过程,从初始赋值开始,主要关注 paratope
是如何逐步影响到边构建的区域的。以下是从代码的核心流程一步步展示如何进行的:
1. 初始化 paratope
参数
在模型的初始化配置中,paratope
参数设置为 ["L3", "H3"]
,这意味着只将 L3
和 H3
片段设置为抗体抗原接口(paratope)。在 E2EDataset
类的 __getitem__
函数中,我们使用 paratope
来生成 paratope_mask
,用于标记哪些片段参与到 paratope
区域内的计算。简化展示:
# 根据 paratope 设置生成 paratope_mask
paratope_mask = [0 for _ in range(len(ag_data['S']) + len(hc_data['S']) + len(lc_data['S']))]
paratope = [self.paratope] if type(self.paratope) == str else self.paratope
for cdr in paratope:
cdr_range = item.get_cdr_pos(cdr)
offset = len(ag_data['S']) + 1 + (0 if cdr[0] == 'H' else len(hc_data['S']))
for idx in range(offset + cdr_range[0], offset + cdr_range[1] + 1):
paratope_mask[idx] = 1
data['paratope_mask'] = paratope_mask
在这里,paratope_mask
的位置为 1 的元素对应于 L3
和 H3
片段。这一掩码在 local_ctx_edges
和 local_inter_edges
的构建中直接限制了边的构建范围。
2. 在 _prepare_batch_constants
中生成 local_mask
和 local_is_ab
paratope_mask
接下来会在 dyMEANModel
的 _prepare_batch_constants
函数中用于生成 local_mask
,标记哪些节点被选为局部节点:
# 使用 paratope_mask 限制边构建的范围
local_mask = torch.logical_or(
paratope_mask, torch.logical_and(is_ag, not_ag_global) # 标记抗原的非全局节点
)
local_mask
将只包含 paratope
内的抗体节点 L3
和 H3
,以及非全局的抗原节点。这样,local_ctx_edges
只会考虑 L3
和 H3
内部的连接,而不会涉及其他抗体片段(例如 H1
或 L1
)。
3. 构建 local_ctx_edges
在 dyMEANModel
的 message_passing
函数中,根据 local_mask
和 paratope_mask
来调用 _knn_edges
函数以构建 local_ctx_edges
:
local_ctx_edges = _knn_edges(
local_X, atom_pos, local_ctx_edges.T,
self.aa_feature.atom_pos_pad_idx, self.k_neighbors,
(offsets, local_batch_id, max_n, gni2lni)
)
这里,local_ctx_edges
将只在 paratope_mask
内部构建边,_knn_edges
根据 k_neighbors
的限制生成 L3-L3
和 H3-H3
的边,而不会跨越到其他区域,也不会包含 E-E
边。
4. 构建 local_inter_edges
对于 local_inter_edges
,则是构建抗体和抗原之间的交互边,同样受到 paratope_mask
限制:
local_inter_edges = _knn_edges(
local_X, atom_pos, local_inter_edges.T,
self.aa_feature.atom_pos_pad_idx, self.k_neighbors,
(offsets, local_batch_id, max_n, gni2lni), given_dist=p_edge_dist
)
local_inter_edges
这里会构建 E-H3
和 E-L3
的边,因为它们被 paratope_mask
标记为抗体抗原交互区域,而不会包含 H1-H2
或其他非 paratope
的交互。
总结
通过 paratope_mask
的逐步应用,local_ctx_edges
仅构建 L3-L3
和 H3-H3
内部边,local_inter_edges
仅包含 E-L3
和 E-H3
边。
local_ctx_edges
不会包含 E-E
边,因为 local_ctx_edges
的构建仅在抗体(Antibody)链内有效,构建边的对象只会是 paratope
设置内的抗体片段,也就是 L3-L3
和 H3-H3
。
因此,local_ctx_edges
只包含 L3-L3
和 H3-H3
的局部边,而不会包含 E-E
的边。
在代码中设置了 args.paratope = ["L3", "H3"]
后,local_ctx_edges
和 local_inter_edges
的构建会受到 paratope
设置的影响,使得只在局部包含 paratope
(即 L3 和 H3)相关的边。
以下是如何在代码实现中影响 local_ctx_edges
和 local_inter_edges
:
-
local_ctx_edges
:local_ctx_edges
会构建同一 segment 内的边,即在paratope
设置下仅包含L3-L3
和H3-H3
的边。- 所以,即使代码有逻辑生成 H1-H1、H2-H2 等同一 segment 的边,这些边不会包含在
local_ctx_edges
中。代码会过滤掉这些非paratope
相关的边,只保留L3-L3
和H3-H3
的局部上下文边。
-
local_inter_edges
:local_inter_edges
只包含paratope
(L3 和 H3)与抗原E
的交互边。- 因此,在代码中设置
args.paratope = ["L3", "H3"]
后,local_inter_edges
只包含L3-E
和H3-E
的交互边,而不会包含 H1-H2、L1-L2 等边。
总结:在设置了 paratope = ["L3", "H3"]
后,local_ctx_edges
只包含 L3-L3
和 H3-H3
,而 local_inter_edges
只包含 L3-E
和 H3-E
的边。
关于不同部分的边是否包含全局节点
在代码逻辑中,local_protein_ids
是基于 local_segment_ids
构建的,而 local_segment_ids
本身是对局部区域(如 E
、H3
和 L3
)的节点进行标记。从代码来看,local_segment_ids
不包含全局节点。具体原因如下:
-
local_segment_ids
构建:
在生成local_segment_ids
时,代码中会过滤掉全局节点,仅保留与特定 CDR 相关的节点(例如H3
和L3
)以及抗原的局部节点(E
),因此全局节点不会出现在local_segment_ids
中。 -
local_protein_ids
的生成:
local_protein_ids
是基于local_segment_ids
的索引依次赋值,表明它只在每个局部区域中唯一标识残基链。因此,local_protein_ids
仅包含局部节点,和local_segment_ids
保持一致。 -
边构建中的限制:
local_ctx_edges
和local_inter_edges
的生成同样基于局部区域,不涉及全局节点的边,从而确保了local_protein_ids
和local_segment_ids
中都不含全局节点。
这样,local_protein_ids
中将仅包含 E
、H3
和 L3
区域的残基,并不包括全局节点。这种设置确保了在局部消息传递时,操作仅限于特定区域的局部节点。
是的,代码中 segment_ids
包含全局节点,而 local_segment_ids
则不包含全局节点。以下是详细说明以及相关代码段验证如何在构建 local_segment_ids
时去掉了全局节点,包括 E
区域的全局节点。
1. segment_ids
的构建
在代码中,segment_ids
是通过 _construct_segment_ids(S)
方法构建的,通常如下:
segment_ids = self._construct_segment_ids(S)
segment_ids
包含所有节点,包括全局节点。全局节点通常使用特殊的标记 (boa_idx
、boh_idx
、bol_idx
),这些标记会根据抗原、重链、轻链等分段进行区分和编号,存储在 segment_ids
中。因此,segment_ids
是包含全局节点的。
2. local_segment_ids
构建过程中去掉全局节点的代码逻辑
在 dyMEANModel
类中,通过以下代码实现 local_segment_ids
的构建,其中去掉了全局节点:
# Filtering to remove global nodes, and only keep specific regions based on args.paratope (e.g., H3, L3)
local_mask = torch.logical_or(
paratope_mask, torch.logical_and(is_ag, not_ag_global)
)
paratope_mask
:根据args.paratope
指定的 CDR 区域(例如H3
和L3
),在S
中标记哪些节点属于 CDR 区域,这里不包含任何全局节点。is_ag
和not_ag_global
:is_ag
表示抗原节点,而not_ag_global
去掉了抗原中的全局节点。
最终,local_mask
只保留了局部节点中的 E
(去掉了抗原中的全局节点)和 args.paratope
指定的 CDR 区域。通过 local_mask
筛选后的 local_segment_ids
不包含任何全局节点。
3. 验证代码
# local_segment_ids 从 segment_ids 中去掉全局节点后的结果
local_segment_ids = segment_ids[local_mask]
通过 local_mask
过滤后,local_segment_ids
最终仅包含 E
区域的局部节点以及指定的 H3
和 L3
区域,去掉了所有的全局节点(包括 E
区域的全局节点)。
是的,paratope_mask
仅标记指定的 paratope 区域(如 H3
和 L3
),不包含任何全局节点。其构建逻辑专门排除了全局节点,确保 paratope_mask
仅作用于需要关注的局部区域。
在代码中,paratope_mask
是通过 args.paratope
参数指定的 CDR 区域来设置的。例如,当 args.paratope
设为 ["H3", "L3"]
时,paratope_mask
将仅标记属于 H3
和 L3
的局部区域节点,而不会包含全局节点 (boa_idx
、boh_idx
、bol_idx
)。
以下代码片段展示了 paratope_mask
的构建过程:
# 初始化 paratope_mask
paratope_mask = [0 for _ in range(len(ag_data['S']) + len(hc_data['S']) + len(lc_data['S']))]
# 根据指定的 paratope 区域,如 H3 和 L3,标记 paratope_mask
paratope = [self.paratope] if type(self.paratope) == str else self.paratope
for cdr in paratope:
cdr_range = item.get_cdr_pos(cdr)
offset = len(ag_data['S']) + 1 + (0 if cdr[0] == 'H' else len(hc_data['S']))
for idx in range(offset + cdr_range[0], offset + cdr_range[1] + 1):
paratope_mask[idx] = 1
这个代码确保 paratope_mask
只对指定的 CDR 区域(H3
和 L3
)的节点生效,不会将全局节点包括进来。
关于三角矩阵是对哪部分进行构建
是的,local_segment_ids
确实不包含全局节点,它仅包括被 local_mask
标记为局部的节点。而对于每个批次(batch)的 local_segment_ids
,根据 paratope
设置为 ["L3", "H3"]
的情况,确实会包含以下三个部分:
- 抗原 (Epitope, E): 包含所有被认为是抗原的非全局节点。
- H3 片段: 作为抗体的 CDR 区域之一,标记为
paratope
片段。 - L3 片段: 作为抗体的另一个 CDR 区域,也被标记为
paratope
片段。
更详细说明
在 local_mask
构建过程中,全局节点已经被排除在外,所以 local_segment_ids
只会保留与 paratope
和抗原有关的非全局节点。因此,local_segment_ids
会包括 E
(抗原)、H3
、L3
,且不会包含任何全局节点。
是的,在构建三角矩阵时,按照当前代码中的逻辑,会对 E
(抗原)、H3
、和 L3
分别构建矩阵。以下是如何实现的解释:
-
局部蛋白质的划分:根据
local_protein_ids
中的分配,将inter_h
特征划分成三部分,分别对应E
(抗原)、H3
、和L3
。 -
矩阵构建:
- 使用
local_protein_ids
的unique_protein_ids
获取每个局部蛋白质。 - 对每个局部蛋白质,构建一个矩阵,矩阵中的每个位置代表两两残基的特征组合。
- 通过
mlp
函数进行逐元素拼接和计算,生成矩阵的每个元素。
- 使用
-
三角消息传递:
- 将每个矩阵依次通过
triangle_multiply_outgoing
、triangle_multiply_ingoing
、triangle_attention_outgoing
和triangle_attention_ingoing
进行三角消息传递和更新,保留局部蛋白质的上下文信息。 - 最后在每个矩阵上批量化操作,以维持不同局部蛋白质之间的独立性,确保
E
、H3
、和L3
的矩阵单独操作。
- 将每个矩阵依次通过
-
特征提取和降维:
- 将消息传递后的矩阵提取出对角线的特征,表示每个残基更新后的特征。
- 使用
reduce
将这些特征降维到统一维度,再拼接成最终的inter_h
。
这种构建方式确保了对 E
、H3
、和 L3
的独立处理,并在局部三角矩阵中保留每部分内部残基之间的消息传递。