Let me help analyze the residue position embedding from the code.

Looking at the AminoAcidEmbedding class in the first file, there are two options for position embeddings:

  1. Relative Position Embedding (relative_position=True):
if relative_position:
    self.res_pos_embedding = SinusoidalPositionEmbedding(res_embed_size)  # relative positional encoding
  2. Absolute Position Embedding (relative_position=False):
else:
    self.res_pos_embedding = nn.Embedding(max_position, res_embed_size)  # absolute position encoding
    # max_position=192 by default

The maximum length of residue position embedding depends on which option is used:

  • For relative position embedding (using SinusoidalPositionEmbedding), there is no fixed maximum length limitation since it uses sinusoidal functions to encode positions. This allows it to theoretically handle any sequence length.

  • For absolute position embedding, there is a fixed maximum length of 192 positions, as specified in the default parameter max_position=192 in the AminoAcidEmbedding constructor.
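
To make the difference concrete, here is a minimal, dependency-free sketch of the two behaviours. It is an illustration under assumptions, not the repository's code: the body of SinusoidalPositionEmbedding is not shown above, so sinusoidal_embedding uses the standard Transformer formulation, and LookupTableEmbedding is a hypothetical stand-in for a learned table such as nn.Embedding(max_position, res_embed_size).

```python
import math


def sinusoidal_embedding(position, dim):
    """Standard Transformer-style sinusoidal encoding of one position.

    Assumption: this is the textbook formulation, not necessarily what
    SinusoidalPositionEmbedding does internally. `dim` must be even.
    """
    emb = []
    for i in range(dim // 2):
        freq = 1.0 / (10000 ** (2 * i / dim))
        emb.append(math.sin(position * freq))
        emb.append(math.cos(position * freq))
    return emb


class LookupTableEmbedding:
    """Hypothetical stand-in for a learned table with a fixed number of rows."""

    def __init__(self, max_position, dim):
        # One row per position index 0 .. max_position - 1.
        self.table = [[0.0] * dim for _ in range(max_position)]

    def __call__(self, position):
        if not 0 <= position < len(self.table):
            raise IndexError(f"position {position} is outside the table")
        return self.table[position]


# The sinusoidal encoding is computed on the fly, so any index works.
far_vec = sinusoidal_embedding(1000, 8)

# The lookup table only has rows for indices 0 .. 191.
table = LookupTableEmbedding(max_position=192, dim=8)
last_valid = table(191)
try:
    table(192)
    out_of_range_failed = False
except IndexError:
    out_of_range_failed = True
```

With a table of 192 rows the valid indices run from 0 to 191, so index 192 is already out of range, while the sinusoidal function accepts position 1000 without any change.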

Looking at the usage in the code, the relative position embedding appears to be the default choice, as seen in the model instantiation:

self.aa_embedding = AminoAcidEmbedding(
    self.num_aa_type, self.num_atom_type, self.num_atom_pos,
    embed_size, embed_size, self.atom_pad_idx, relative_position=True)

So in summary:

  • By default, it uses relative positional encoding which has no fixed maximum length
  • If switched to absolute positional encoding (relative_position=False), then it has a fixed maximum length of 192 positions
  • The default max_position of 192 appears to be chosen based on the IMGT numbering scheme for antibodies, as commented in the code

I am wondering whether the antigen, heavy chain, and light chain use different positions?

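Yes. The antigen, heavy chain, and light chain are numbered independently: the position counter restarts at the beginning of each chain instead of running continuously across the whole sequence. Let me explain by looking at the code: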

In the AminoAcidFeature class, the method _construct_residue_pos handles this:

def _construct_residue_pos(self, S):
    # construct residue position. global node is 1, the first residue is 2, ... (0 for padding)
    glbl_node_mask = self._is_global(S)  # True at global nodes (BOA, BOH, BOL)
    glbl_node_idx = torch.nonzero(glbl_node_mask).flatten()  # [batch_size * 3] (boa, boh, bol)
    # Per-global-node step: 1 for the first node; for every later node,
    # (previous index - current index + 1), which pulls the running sum back to 1
    shift = F.pad(glbl_node_idx[:-1] - glbl_node_idx[1:] + 1, (1, 0), value=1)
    # Every ordinary residue advances the position by one
    residue_pos = torch.ones_like(S)
    # Global nodes use their shift instead, resetting the count
    residue_pos[glbl_node_mask] = shift
    # Cumulative sum turns the per-token steps into positions
    residue_pos = torch.cumsum(residue_pos, dim=0)
    return residue_pos

Let's break down how the positions are assigned:

  1. The sequence starts with three special tokens:

    • BOA (Begin of Antigen)
    • BOH (Begin of Heavy chain)
    • BOL (Begin of Light chain)
  2. At each special token, position numbering restarts: the token itself gets position 1 and the first residue after it gets position 2:

    • Antigen residues: Start after BOA token
    • Heavy chain residues: Start after BOH token
    • Light chain residues: Start after BOL token
  3. The positioning is done through cumulative sum, ensuring:

    • Each section (antigen/heavy/light) has its own independent position numbering
    • The positions are reset at each global token (BOA/BOH/BOL)

For example, a sequence might be numbered like this:

BOA   Ag1  Ag2  Ag3   BOH   H1   H2   H3   BOL   L1   L2   L3
1     2    3    4     1     2    3    4    1     2    3    4

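Because the numbering restarts at every global token, a position index is local to its chain: it encodes where a residue sits within the antigen, heavy chain, or light chain, not where it sits in the concatenated sequence. Note that the same index recurs in every chain (Ag1, H1, and L1 all get position 2), so the position embedding alone does not tell the chains apart.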
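
The reset-by-cumulative-sum trick can be checked by hand with a small sketch. The version below is a plain-Python mirror of the tensor operations, written for illustration only: the function name construct_residue_pos, the boolean-list input, and the token labels are assumptions made for this example, and it presumes the sequence begins with a global token.

```python
from itertools import accumulate


def construct_residue_pos(is_global):
    """Return per-chain positions for a flat token sequence.

    is_global: list of bools, True where the token is a global node
    (BOA/BOH/BOL). Hypothetical helper mirroring the tensor logic.
    """
    glbl_idx = [i for i, flag in enumerate(is_global) if flag]
    # First global node contributes 1; each later one contributes
    # (previous index - current index + 1), which is <= 0 and pulls the
    # running sum back down to exactly 1.
    shifts = [1] + [prev - cur + 1 for prev, cur in zip(glbl_idx[:-1], glbl_idx[1:])]
    # Ordinary residues each advance the position by one.
    steps = [1] * len(is_global)
    for idx, shift in zip(glbl_idx, shifts):
        steps[idx] = shift
    # Running sum of the steps gives the final positions.
    return list(accumulate(steps))


tokens = ["BOA", "Ag1", "Ag2", "Ag3", "BOH", "H1", "H2", "H3", "BOL", "L1", "L2", "L3"]
is_global = [t in {"BOA", "BOH", "BOL"} for t in tokens]
positions = construct_residue_pos(is_global)
# positions == [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
```

For the twelve-token example the per-token steps are [1, 1, 1, 1, -3, 1, 1, 1, -3, 1, 1, 1]; the two -3 entries sit on BOH and BOL, and they are what bring the running sum back to 1 at the start of each chain. The same logic works when the chains have different lengths, because each shift is computed from the distance to the previous global token.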

