Let me analyze the residue position embedding in the code.
Looking at the AminoAcidEmbedding class in the first file, there are two options for position embeddings:
- Relative position embedding:
      if relative_position:
          self.res_pos_embedding = SinusoidalPositionEmbedding(res_embed_size)  # relative positional encoding
- Absolute position embedding:
      else:
          self.res_pos_embedding = nn.Embedding(max_position, res_embed_size)  # absolute position encoding
          # max_position=192 by default
The maximum length of the residue position embedding depends on which option is used:
- For relative position embedding (using SinusoidalPositionEmbedding), there is no fixed maximum length, because positions are encoded with sinusoidal functions that can be evaluated at any integer. In principle it can handle any sequence length.
- For absolute position embedding, there is a fixed maximum of 192 positions, set by the default parameter max_position=192 in the AminoAcidEmbedding constructor. nn.Embedding(192, ...) has 192 rows, so every position index must lie between 0 and 191; a larger index is out of range.
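To make the difference concrete, here is a framework-free Python sketch. It assumes the standard Transformer sinusoid (dyMEAN's SinusoidalPositionEmbedding may order or scale the sin/cos terms differently), and table_embedding is a hypothetical stand-in for an nn.Embedding lookup:

```python
import math

MAX_POSITION = 192  # default max_position in AminoAcidEmbedding


def sinusoidal_embedding(pos, dim):
    """Standard Transformer sinusoid (an assumption here); defined for any integer position."""
    emb = []
    for i in range(0, dim, 2):
        # i plays the role of 2k in PE(pos, 2k) = sin(pos / 10000^(2k/d))
        angle = pos / (10000 ** (i / dim))
        emb.append(math.sin(angle))
        emb.append(math.cos(angle))
    return emb[:dim]  # trim the extra cos term when dim is odd


def table_embedding(table, pos):
    """Hypothetical learned lookup like nn.Embedding: only rows 0..len(table)-1 exist."""
    if not 0 <= pos < len(table):
        raise IndexError(f"position {pos} outside table of size {len(table)}")
    return table[pos]
```

With a toy table of MAX_POSITION rows, table_embedding(table, 191) succeeds while table_embedding(table, 192) raises IndexError, whereas sinusoidal_embedding(10_000, 8) still returns an 8-dimensional vector.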
Looking at how the model uses it, relative position embedding appears to be the option actually used, since the model instantiates the embedding with relative_position=True:
    self.aa_embedding = AminoAcidEmbedding(
        self.num_aa_type, self.num_atom_type, self.num_atom_pos,
        embed_size, embed_size, self.atom_pad_idx, relative_position=True)
So in summary:
- By default, it uses relative positional encoding, which has no fixed maximum length
- If switched to absolute positional encoding (relative_position=False), it has a fixed maximum of 192 positions
- The max_position value of 192 appears to be chosen based on the IMGT numbering scheme for antibodies, as a comment in the code notes
I am wondering: do the antigen, heavy chain, and light chain use different positions?
Yes, antigen, heavy chain, and light chain use different position indices. Let me explain by looking at the code:
In the AminoAcidFeature class, the method _construct_residue_pos handles this:
    import torch
    import torch.nn.functional as F

    def _construct_residue_pos(self, S):
        # construct residue position. global node is 1, the first residue is 2, ... (0 for padding)
        glbl_node_mask = self._is_global(S)  # identifies global nodes (BOA, BOH, BOL)
        glbl_node_idx = torch.nonzero(glbl_node_mask).flatten()  # [batch_size * 3] (boa, boh, bol)
        # shift at each global node = previous global index - current global index + 1
        # (the first global node gets 1), which makes the cumsum reset to 1 there
        shift = F.pad(glbl_node_idx[:-1] - glbl_node_idx[1:] + 1, (1, 0), value=1)
        # every non-global residue contributes +1 to the running count
        residue_pos = torch.ones_like(S)
        residue_pos[glbl_node_mask] = shift
        # cumulative sum over the flattened batch turns the increments into positions
        residue_pos = torch.cumsum(residue_pos, dim=0)
        return residue_pos
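The tensor logic can be traced without PyTorch. The sketch below mirrors each line of _construct_residue_pos on a plain list of token labels; the string labels and the GLOBAL_TOKENS set are hypothetical stand-ins for the token ids that self._is_global checks, and, like the original, it assumes the flattened sequence begins with a global token:

```python
from itertools import accumulate

# Hypothetical labels; the real code tests token ids via self._is_global(S).
GLOBAL_TOKENS = {"BOA", "BOH", "BOL"}


def construct_residue_pos(tokens):
    """Pure-Python mirror of _construct_residue_pos on a flattened token list."""
    # glbl_node_idx: flat indices of the global tokens
    glbl_node_idx = [i for i, t in enumerate(tokens) if t in GLOBAL_TOKENS]
    # shift = F.pad(idx[:-1] - idx[1:] + 1, (1, 0), value=1)
    shift = [1] + [prev - cur + 1 for prev, cur in zip(glbl_node_idx[:-1], glbl_node_idx[1:])]
    # residue_pos = ones_like(S); residue_pos[glbl_node_mask] = shift
    residue_pos = [1] * len(tokens)
    for idx, s in zip(glbl_node_idx, shift):
        residue_pos[idx] = s
    # torch.cumsum(residue_pos, dim=0)
    return list(accumulate(residue_pos))
```

For the sequence BOA Ag1 Ag2 Ag3 BOH H1 H2 H3 BOL L1 L2 L3, it returns [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]. Because the batch is flattened, the numbering also resets at the BOA of the next complex.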
Let's break down how the positions are assigned:
- Each chain segment begins with a special global token:
  - BOA (beginning of antigen)
  - BOH (beginning of heavy chain)
  - BOL (beginning of light chain)
- At each global token the numbering starts fresh: the token itself gets position 1 and the residues that follow get 2, 3, ...
  - Antigen residues are numbered from the BOA token
  - Heavy chain residues are numbered from the BOH token
  - Light chain residues are numbered from the BOL token
- The positions come from a cumulative sum: every residue adds 1, and the shift stored at each global token cancels the previous segment's count. This ensures that:
  - Each section (antigen/heavy/light) has its own independent position numbering
  - The positions reset at each global token (BOA/BOH/BOL), including the BOA of the next complex in the flattened batch
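The reset can be checked with a short derivation. Write g_1 < g_2 < ... for the flat indices of the global tokens (assuming, as in the flattened batch, that the first token is a global token, so g_1 = 0), r_j for the value stored at index j before the cumsum, and p_j for the resulting position:

```latex
\begin{aligned}
r_j &= \begin{cases}
  g_{k-1} - g_k + 1, & j = g_k,\ k \ge 2 \\
  1, & \text{otherwise}
\end{cases}
\qquad
p_j = \sum_{t=0}^{j} r_t \\[4pt]
p_j &= j - g_k + 1 \quad \text{for } g_k \le j < g_{k+1} \\[4pt]
p_{g_{k+1}} &= p_{g_{k+1}-1} + r_{g_{k+1}} = (g_{k+1} - g_k) + (g_k - g_{k+1} + 1) = 1
\end{aligned}
```

For BOA Ag1 Ag2 Ag3 BOH H1 H2 H3 BOL L1 L2 L3, the global tokens sit at g = (0, 4, 8), so shift = (1, -3, -3) and the position at BOH is (4 - 0) + (0 - 4 + 1) = 1.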
For example, a sequence might be numbered like this:
    Token:    BOA  Ag1  Ag2  Ag3  BOH  H1   H2   H3   BOL  L1   L2   L3
    Position: 1    2    3    4    1    2    3    4    1    2    3    4
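This separate numbering gives each residue a position relative to the start of its own chain, so positional information within a chain does not depend on how long the preceding chains are. Note that the numbering alone does not tell chains apart, since every chain reuses the same indices; chain identity has to come from elsewhere, such as the global token that begins each segment.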
From: https://www.cnblogs.com/GraphL/p/18634168