
Pre-class Preparation: Single-cell Transcriptome Analysis Combined with VDJ Data


Author: Evil Genius

The analysis we want to implement here is VDJ clustering and motif analysis.

The analysis will be covered in class; the registration link is in the 2024 single-cell and spatial course series.
The complete script (packaged version) follows:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import TensorBoard
import tensorflow_addons as tfa  # provides the triplet losses used below
#import tensorflow_datasets as tfds


import csv
import os 
import argparse

#import other python files
import config_script
import model_architecture
import enc_decode
import plot_latent_space
import cluster_performance_functions

# parse the command line
if __name__ == "__main__":
    # initialize Argument Parser
    parser = argparse.ArgumentParser()
    # TCR input parameters
    parser.add_argument("--input_type", help="beta_chain or paired_chain", type=str)
    parser.add_argument("--encoding_type", help="onehot or ESM", type=str) ## one-hot or ESM embeddings
    parser.add_argument("--esm_type", help="ESM1v or ESM2", type=str, default='ESM1v')
    parser.add_argument("--chain_type", help="TCR or ab_VC or ab_V", type=str, default='ab_V')

    # model parameters
    parser.add_argument("--embedding_space", type=int) ## embedding dimension for TCR
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--patience", type = int, default = 15)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=0.0001)
    parser.add_argument("--triplet_loss", help="hard or semihard", type=str, default='semihard')
    
    # plot embeddings
    # NOTE: argparse's type=bool treats any non-empty string as True,
    # so a store_true flag is the safer idiom
    parser.add_argument("--plot_embedding", action="store_true", default=False)


######################################## main script body
# read the parsed arguments from the command line
args = parser.parse_args()
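# A minimal example invocation (a sketch: the script filename "run_triplet.py"
# and the embedding dimension are assumptions; all other flags fall back to
# the defaults defined above):
#   python run_triplet.py --input_type paired_chain --encoding_type ESM \
#       --esm_type ESM2 --chain_type ab_V --embedding_space 32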

INPUT_TYPE = args.input_type  #beta_chain or paired_chain
ENCODING_TYPE = args.encoding_type
ESM_TYPE = args.esm_type
CHAIN_TYPE = args.chain_type

EMBEDDING_SPACE = args.embedding_space
EPOCHS = args.epochs
PATIENCE = args.patience
BATCH_SIZE = args.batch_size
LRATE = args.learning_rate
LOSS = args.triplet_loss # 'semihard' or 'hard'

PLOT_EMB = args.plot_embedding  

########################################################################################
############################ Open TCRpMHC and encode them ##############################

# open the files and assign some constant variables
TCRpMHC = pd.read_csv(config_script.TCRpMHC_ESM2_data, header = 0)
epi_list = TCRpMHC['Epitope'].unique().tolist()
epi_color = plot_latent_space.palette_31[:len(epi_list)]
epi_color_dict = dict(zip(epi_list, epi_color))
TCRpMHC['color'] = TCRpMHC['Epitope'].map(epi_color_dict)

TCRpMHC['CDR123ab'] = TCRpMHC['CDR1a'] + TCRpMHC['CDR2a'] +  TCRpMHC['CDR3a'] + \
    TCRpMHC['CDR1b'] + TCRpMHC['CDR2b'] + TCRpMHC['CDR3b']
TCRpMHC['CDR123b'] = TCRpMHC['CDR1b'] + TCRpMHC['CDR2b'] + TCRpMHC['CDR3b']
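# NOTE: these concatenated CDR strings are used later as the sequence column
# handed to the clustering-performance functions (see seq_column below)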


if ENCODING_TYPE == 'onehot':      
    model_input_dict = config_script.model_input_dict
    model_input_key = INPUT_TYPE
    df_columns = model_input_dict[model_input_key][0] #get columns to encode
    max_seq_len = model_input_dict[model_input_key][1] #get max length for padding
    print ('df_columns: ', df_columns, 'max_seq_len: ', max_seq_len)

    # make a dist of one-hot encoded amino acids for data encoding:
    AA_list = list("ACDEFGHIKLMNPQRSTVWY_")
    AA_one_hot_dict = {char : l for char, l in zip(AA_list, np.eye(len(AA_list), k=0))}
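    # e.g. AA_one_hot_dict['A'] is a 21-dim one-hot vector (20 amino acids
    # plus '_' as the padding character)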

if ENCODING_TYPE == 'ESM':
    if ESM_TYPE == 'ESM1v':
        if CHAIN_TYPE == 'TCR':
            embeddings = np.loadtxt(config_script.esm_embed_dir + 'albeta_VC_esm1v.txt', dtype = float)
        if CHAIN_TYPE == 'ab_VC':
            if INPUT_TYPE == 'beta_chain':
                embeddings = np.loadtxt(config_script.esm_embed_dir + 'beta_VC_esm1v.txt', dtype = float)
            if INPUT_TYPE == 'paired_chain':
                embeddings_b = np.loadtxt(config_script.esm_embed_dir + 'beta_VC_esm1v.txt', dtype = float)
                embeddings_a = np.loadtxt(config_script.esm_embed_dir + 'alpha_VC_esm1v.txt', dtype = float)
                embeddings = np.concatenate((embeddings_a, embeddings_b), axis=1) 
        if CHAIN_TYPE == 'ab_V':
            if INPUT_TYPE == 'beta_chain':
                embeddings = np.loadtxt(config_script.esm_embed_dir + 'beta_V_esm1v.txt', dtype = float)
            if INPUT_TYPE == 'paired_chain':
                embeddings_a = np.loadtxt(config_script.esm_embed_dir + 'alpha_V_esm1v.txt', dtype = float)
                embeddings_b = np.loadtxt(config_script.esm_embed_dir + 'beta_V_esm1v.txt', dtype = float)
                embeddings = np.concatenate((embeddings_a, embeddings_b), axis=1) 

    if ESM_TYPE == 'ESM2':
        if CHAIN_TYPE == 'TCR':
            embeddings = np.loadtxt(config_script.esm_embed_dir + 'albeta_VC_esm2.txt', dtype = float)
        if CHAIN_TYPE == 'ab_VC':
            if INPUT_TYPE == 'beta_chain':
                embeddings = np.loadtxt(config_script.esm_embed_dir + 'beta_VC_esm2.txt', dtype = float)
            if INPUT_TYPE == 'paired_chain':
                embeddings_a = np.loadtxt(config_script.esm_embed_dir + 'alpha_VC_esm2.txt', dtype = float)
                embeddings_b = np.loadtxt(config_script.esm_embed_dir + 'beta_VC_esm2.txt', dtype = float)
                embeddings = np.concatenate((embeddings_a, embeddings_b), axis=1) 
        if CHAIN_TYPE == 'ab_V':
            if INPUT_TYPE == 'beta_chain':
                embeddings = np.loadtxt(config_script.esm_embed_dir + 'beta_V_esm2.txt', dtype = float)
            if INPUT_TYPE == 'paired_chain':
                embeddings_a = np.loadtxt(config_script.esm_embed_dir + 'alpha_V_esm2.txt', dtype = float)
                embeddings_b = np.loadtxt(config_script.esm_embed_dir + 'beta_V_esm2.txt', dtype = float)
                embeddings = np.concatenate((embeddings_a, embeddings_b), axis=1) 
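# NOTE: if the ENCODING_TYPE / ESM_TYPE / CHAIN_TYPE combination matches none
# of the branches above, `embeddings` is never defined and the fold loop below
# fails with a NameError; validating the flag combination up front would fail faster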

# K-fold cross-validation model evaluation (run once model parameters are settled)
tf.keras.utils.set_random_seed(1234) ## set ALL random seeds for the program
all_results = pd.DataFrame()
test_val_folds = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 0]]
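# each pair is (validation fold, test fold); with seven folds in total, the
# remaining five folds form the training set on every rotation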
for test_val in test_val_folds:
    test_val_fold_str = str(test_val[0]) + str(test_val[1])
    val = TCRpMHC[TCRpMHC['part']==test_val[0]]
    test = TCRpMHC[TCRpMHC['part']==test_val[1]]
    train = TCRpMHC[~TCRpMHC['part'].isin(test_val)]
    

    #encode the data
    if ENCODING_TYPE == 'onehot':
        X_train = enc_decode.CDR_one_hot_encoding(train, df_columns, AA_list, AA_one_hot_dict, max_seq_len)
        X_val = enc_decode.CDR_one_hot_encoding(val, df_columns, AA_list, AA_one_hot_dict, max_seq_len)
        X_test = enc_decode.CDR_one_hot_encoding(test, df_columns, AA_list, AA_one_hot_dict, max_seq_len)
        #X_all_data = enc_decode.CDR_one_hot_encoding(TCRpMHC, df_columns, AA_list, AA_one_hot_dict, max_seq_len)
        print ('Seqs encoded!')

    if ENCODING_TYPE == 'ESM':
        #1) get indexes
        train_idx = train.index.tolist()
        val_idx = val.index.tolist()
        test_idx = test.index.tolist()

        #2) get embeddings
        X_train = embeddings[train_idx]
        X_val = embeddings[val_idx]
        X_test = embeddings[test_idx]

        #3) change the dimension to train CNN
        dim_2 = int(len(X_train[0])/20) #get the dimension for reshaping
        X_train = np.reshape(X_train, (-1, dim_2, 20)) #it's either 64*20 or 128*20, depending on TCR input type
        X_val = np.reshape(X_val, (-1, dim_2, 20))
        X_test = np.reshape(X_test, (-1, dim_2, 20))

    y_train = train['label'].values
    y_val = val['label'].values
    y_test = test['label'].values

    #save the Epi label for plotting
    y_train_epi = train['Epitope'].tolist()
    y_val_epi = val['Epitope'].tolist()
    y_test_epi = test['Epitope'].tolist()

    print ('X_train: ', X_train.shape, \
           'X_val: ',  X_val.shape, \
           'X_test: ', X_test.shape)

    #turn array into tensors
    train_triplet = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_triplet = train_triplet.shuffle(1024).batch(BATCH_SIZE)
    val_triplet = tf.data.Dataset.from_tensor_slices((X_val, y_val))
    val_triplet = val_triplet.batch(BATCH_SIZE)
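    # the TFA triplet losses mine triplets online within each batch, so every
    # batch should contain multiple examples per epitope label; shuffling
    # before batching helps mix the labels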

######################################################################################
############################ Set up & train a model  #################################

    if LOSS == 'hard':
        loss = tfa.losses.TripletHardLoss()
    if LOSS == 'semihard':
        loss = tfa.losses.TripletSemiHardLoss()


    tf.keras.backend.clear_session()

    triplet = model_architecture.Triplet_CNN_naive(X_train.shape[1:], embedding_units = EMBEDDING_SPACE)
    triplet.compile(optimizer=tf.keras.optimizers.Adam(LRATE), loss=loss) #tfa.losses.TripletSemiHardLoss()
    triplet.summary()
    early_stop = EarlyStopping(monitor='val_loss', mode='min', patience = PATIENCE, \
                       verbose = 1, restore_best_weights = True)
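    # restore_best_weights=True means the model saved below carries the weights
    # from the best validation epoch, not from the final epoch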
    
    print ('Starting training...')
    history = triplet.fit(
        train_triplet,
        epochs=EPOCHS, 
        validation_data = val_triplet, 
        callbacks = [early_stop])

    print ('Model trained!')

    # Extract the best loss value (from early stopping) 
    loss_hist = history.history['val_loss']
    best_val_loss = np.min(loss_hist)
    best_epochs = np.argmin(loss_hist) + 1

    if ENCODING_TYPE == 'onehot':
        TRIPLET_NAME = f'triplet_{ENCODING_TYPE}_{INPUT_TYPE}_{EMBEDDING_SPACE}d_{LRATE}lr_{BATCH_SIZE}btch_{EPOCHS}ep_{PATIENCE}ptence_fold{test_val_fold_str}'
    if ENCODING_TYPE == 'ESM':
        TRIPLET_NAME = f'triplet_{INPUT_TYPE}_{ESM_TYPE}_{CHAIN_TYPE}_{EMBEDDING_SPACE}d_{LRATE}lr_{BATCH_SIZE}btch_{EPOCHS}ep_{PATIENCE}ptence_fold{test_val_fold_str}'

    modelpath = config_script.triplet_out_model + TRIPLET_NAME
    if not os.path.exists(modelpath):
        os.makedirs(modelpath)
    # save the model (reusing the path created above)
    triplet.save(modelpath)

    
    #print ('\n')
    print ('Model name: ', TRIPLET_NAME)
    print ('best loss: ', best_val_loss, 'best_epochs: ', best_epochs)
    print ('Triplet model saved!')
    
    ############################## extract clustering performance ###############################
    ############################## ################################ #############################
    if INPUT_TYPE == 'beta_chain':
        seq_column = 'CDR123b'
    if INPUT_TYPE == 'paired_chain':
        seq_column = 'CDR123ab'
    
    results_val = triplet(X_val)
    val_emb_numpy = results_val.numpy()
    results_test = triplet(X_test)
    test_emb_numpy = results_test.numpy()
    
    epsilon_threshold = [0.2, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.7, 0.85, 0.95]
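    # sweep DBSCAN's eps neighborhood radius and record the clustering metrics
    # for the validation and test splits at every threshold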
    for epsilon in epsilon_threshold:
        result_dbscan_val = cluster_performance_functions.get_performance_DBSCAN(val, TRIPLET_NAME, \
                                                        val_emb_numpy, seq_column, epsilon)
        result_dbscan_val['split'] = 'val'
        result_dbscan_val['fold'] = test_val[0]
    
    
        result_dbscan_test = cluster_performance_functions.get_performance_DBSCAN(test, TRIPLET_NAME, \
                                                            test_emb_numpy, seq_column, epsilon)
        result_dbscan_test['split'] = 'test'
        result_dbscan_test['fold'] = test_val[1]

        all_results = pd.concat([all_results, result_dbscan_val, result_dbscan_test], ignore_index = True)


all_results.to_csv(config_script.triplet_analysis_dir + TRIPLET_NAME + '_WITH_FOLDS_clstr_metrics.csv')

"""
########################################################################################
############################## Plot model embeddings ###################################
triplet = keras.models.load_model(config_script.triplet_out_model + TRIPLET_NAME)  
if PLOT_EMB == True:
    
    results_val = triplet(X_val)
    results_test = triplet(X_test)
    results_train = triplet(X_train)

    umap_pca_test = plot_latent_space.umap_pca_df(results_test, y_test_epi, random_seed = 42)
    plot_latent_space.plot_umap_pca_latent(umap_pca_test, 'Epitope', color_dict = epi_color_dict, \
                         fig_title = f'Epitope distribution by {TRIPLET_MODE} model with {INPUT_TYPE} triplet {EMBEDDING_SPACE}embedding ', \
                         legend_title = 'Epitope', node_size = 70, \
                         save = True, figname = f'Epitope_{TRIPLET_NAME}.jpg')
"""

Reference: TCR clustering by contrastive learning on antigen specificity

From: https://blog.csdn.net/weixin_53637133/article/details/139669349

    一、研究背景和意义随着房地产市场的快速发展,房价数据成为了人们关注的焦点。了解房价的分布特征、影响因素以及不同区域之间的差异对于购房者、房地产开发商、政府部门等都具有重要的意义。通过对房价数据的聚类分析,可以深入了解房价的内在结构和规律,为相关决策提供科学依据......