首页 > 编程问答 >HuggingFace:使用 Transformer 对 DNA 序列进行高效大规模嵌入提取

HuggingFace:使用 Transformer 对 DNA 序列进行高效大规模嵌入提取

时间:2024-08-09 11:54:46浏览次数:7  
标签:python huggingface-transformers huggingface huggingface-tokenizers huggingface-d

我有一个非常大的数据框(60+ 百万行),我想使用转换器模型来获取这些行(DNA 序列)的嵌入。基本上,这首先涉及标记化,然后我可以获得嵌入。 由于 RAM 限制,我发现标记化然后将所有内容嵌入到一个 py 文件中是行不通的。这是我发现的解决方法,适用于大约 3000 万行的数据帧(但不适用于较大的 df):

  1. 标记化 - 将输出保存为 200 个块/分片
  2. 将这 200 个块分别提供给获取嵌入
  3. 这些嵌入,然后连接成一个更大的嵌入文件

最终嵌入文件应包含以下列: [['Cromosome', 'label', 'embeddings']]

总的来说,我对如何让它适用于我的更大数据集有点迷失。

我已经研究过流媒体数据集,但我认为这实际上没有帮助,因为我需要所有嵌入,而不仅仅是一些嵌入。如果我流式传输标记化并将其一次一点地输入到嵌入过程中(沿途删除标记),也许它可以工作。这样,我就不必保存令牌了。如果这不可行,请纠正我。

理想情况下,我希望避免对数据进行分片,但我只是希望代码在此时运行而不达到 RAM 限制。

步骤 1|| |第 2 步(此代码在 200 个分片中运行)

dataset = Dataset.from_pandas(element_final[['Chromosome', 'sequence', 'label']]) 

dataset = dataset.shuffle(seed=42)
tokenizer = AutoTokenizer.from_pretrained(f"InstaDeepAI/nucleotide-transformer-500m-human-ref")
def tokenize_function(examples):
    outputs = tokenizer.batch_encode_plus(examples["sequence"], return_tensors="pt", truncation=False, padding=False, max_length=80)
    return outputs
    
# Creating tokenized  dataset
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True, batch_size=2000)

tokenized_dataset.save_to_disk(f"tokenized_elements/tokenized_{ELEMENT}", num_shards=200)

第 3 步:将所有 200 个嵌入分片连接成 1 个分片。

input_file = f"tokenized_elements/tokenized_{ELEMENT_LABEL}/{filename}.arrow"

# Load input data
d1 = Dataset.from_file(input_file)

def embed_function(examples):
    torch.cuda.empty_cache()
    gc.collect()

    inputs = torch.tensor(examples['input_ids'])  # Convert to tensor
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(input_ids=inputs, output_hidden_states=True)

    # Step 3: Extract the embeddings
    hidden_states = outputs.hidden_states  # List of hidden states from all layers
    embeddings = hidden_states[-1]  # Assuming you want embeddings from the last layer
    averaged_embeddings = torch.mean(embeddings, dim=1)  # Calculate mean along dimension 1 (the dimension with size 86)
    averaged_embeddings = averaged_embeddings.to(torch.float32)  # Ensure float32 data type
    return {'embeddings': averaged_embeddings}

# Map embeddings function to input data
embeddings = d1.map(embed_function, batched=True, batch_size=1550)
embeddings = embeddings.remove_columns(["input_ids", "attention_mask"])

# Save embeddings to disk
output_dir = f"embedded_elements/embeddings_{ELEMENT_LABEL}/{filename}"  # Assuming ELEMENT_LABEL is defined elsewhere

step 3: concatenate all 200 shards of embeddings into 1.


以下是如何使用 Hugging Face Transformers 有效地从 DNA 序列中提取大规模嵌入的方法,并解决遇到的内存问题:

import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
import gc

# 定义设备(CPU 或 GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义模型和分词器
model_name = "InstaDeepAI/nucleotide-transformer-500m-human-ref" 
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

# 定义一个自定义数据集类
class DNADataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length):
        self.sequences = dataframe['sequence'].tolist()
        self.labels = dataframe['label'].tolist()
        self.chromosomes = dataframe['Chromosome'].tolist()
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
        chromosome = self.chromosomes[idx]

        encoding = self.tokenizer(
            sequence,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label),
            'chromosome': chromosome
        }

# 定义数据加载器和嵌入提取函数
def get_embeddings(model, dataloader, device):
    model.eval()
    all_embeddings = []
    all_labels = []
    all_chromosomes = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :].cpu() # 使用 [CLS] 标记的嵌入

            all_embeddings.append(embeddings)
            all_labels.extend(batch['label'])
            all_chromosomes.extend(batch['chromosome'])

    return torch.cat(all_embeddings), all_labels, all_chromosomes

# 加载数据帧(分块加载以节省内存)
chunksize = 1000000  # 根据需要调整分块大小
embeddings_list = []
labels_list = []
chromosomes_list = []

for chunk in pd.read_csv("your_data.csv", chunksize=chunksize):
    # 创建数据集和数据加载器
    dataset = DNADataset(chunk, tokenizer, max_length=80)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False) # 根据需要调整批大小

    # 提取嵌入
    embeddings, labels, chromosomes = get_embeddings(model, dataloader, device)

    # 将结果存储到列表中
    embeddings_list.append(embeddings)
    labels_list.extend(labels)
    chromosomes_list.extend(chromosomes)

    # 释放内存
    del dataset, dataloader, embeddings, labels, chromosomes
    gc.collect()
    torch.cuda.empty_cache()

# 连接所有分块的结果
all_embeddings = torch.cat(embeddings_list)
all_labels = labels_list
all_chromosomes = chromosomes_list

# 创建最终数据帧
final_df = pd.DataFrame({
    'Chromosome': all_chromosomes,
    'label': all_labels,
    'embeddings': all_embeddings.tolist() # 将张量转换为列表以存储到数据帧
})

# 保存最终数据帧
final_df.to_csv("embeddings.csv", index=False)

说明:

  1. 数据分块: 该代码使用 pandas.read_csv chunksize 参数将大型数据帧加载到内存中。这允许以较小的块处理数据,从而防止内存溢出。
  2. 自定义数据集和数据加载器: 使用 torch.utils.data.Dataset torch.utils.data.DataLoader 可以有效地加载和预处理数据。这使可以对批处理、混洗和将数据加载到设备(CPU/GPU)进行细粒度控制。
  3. GPU 加速(如果可用): 代码会自动检测并使用可用的 GPU 来加速嵌入提取过程。
  4. 批处理: 嵌入提取在批处理模式下执行,这有助于提高效率并减少内存消耗。
  5. 内存管理: 该代码使用 del gc.collect() torch.cuda.empty_cache() 在每个分块处理后显式释放内存。这有助于最大限度地减少内存使用并防止内存泄漏。

其他提示:

  • 调整批大小和分块大小: 根据的硬件资源和数据集大小,可能需要调整 batch_size chunksize 参数以优化性能和内存使用。
  • 使用更小的模型: 如果仍然遇到内存问题,请考虑使用更小、更快的 Transformer 模型。
  • 降维: 可以使用 PCA 或 t-SNE 等降维技术来减少嵌入的维数。这可以帮助减少内存使用,而不会牺牲太多信息。

通过遵循这些步骤,应该能够有效地从大型 DNA 序列数据集中提取嵌入,而不会遇到内存问题。

标签:python,huggingface-transformers,huggingface,huggingface-tokenizers,huggingface-d
From: 78837463

相关文章

  • open3d python 法线估计
    测试效果废话Open3D中的法线估计是一个重要的功能,它可以帮助用户了解三维点云中每个点的局部表面方向。以下是对Open3D法线估计的详细解释:一、法线估计的基本原理法线估计通常基于局部表面拟合的方法。在点云数据中,每个点的局部邻域可以视为一个平面或曲面的近似。通......
  • jenkins的shell command中如何让python 实时显示执行日志
    在使用Jenkins的shellcommand里面执行python脚本时,我们希望在构建shell脚本时可以实时输出日志,但是在构建python脚本时,是等到python执行完成以后,才显示结果,这个对于我们判断脚本执行状态非常不友好。而之所以会出现这种情况,是因为python默认是有缓存的,所以我们需要禁用输入......
  • onnx转engine工具(包含量化) python脚本
    量化工具在网上搜索五花八门,很多文章没有说明使用的版本导致无法复现,这里参考了一些写法实现量化,并转为engine。具体实现代码见下方,欢迎各位小伙伴批评指正。tensorrt安装参考windows11下安装TensorRT,并在conda虚拟环境下使用_tensorrt免费吗-CSDN博客pycuda安装参考GPU......
  • 20:Python函数
    #Python3函数#函数是组织好的,可重复使用的,用来实现单一,或相关联功能的代码段。#函数能提高应用的模块性,和代码的重复利用率。你已经知道Python提供了许多内建函数,比如print()。#但你也可以自己创建函数,这被叫做用户自定义函数。#定义一个函数#你可以定义一个由自己想要功能......
  • 使用python做页面,测试数据库连通性!免费分享!测试通过~
    免费分享刚刚写的一个小程序,测试通过没问题,解BUG也就花了半小时吧有更好的方法欢迎评论区推给我谢谢。importtkinterastkfromtkinterimportmessageboximportpymysqldefget_db_info(db_source):ifdb_source=='database1':hostname=e1.get()......
  • Python面试宝典第30题:找出第K大元素
    题目        给定一个整数数组nums,请找出数组中第K大的数,保证答案存在。其中,1<=K<=nums数组长度。        示例1:输入:nums=[3,2,1,5,6,4],K=2输出:5        示例2:输入:nums=[50,23,66,18,72],K=3输出:50快速选择算法......
  • 使用Python和Flask框架实现简单的RESTful API
    目录环境准备创建Flask应用运行Flask应用测试API注意事项在当今的Web开发领域,RESTfulAPI因其简洁性和高效性而备受欢迎。本文将引导你使用Python的Flask框架来创建一个简单的RESTfulAPI,用于增删改查(CRUD)用户信息。环境准备在开始之前,请确保你的Python环境中已经安......
  • nodejs语言,MySQL数据库;springboot的个性化资讯推荐系统66257(免费领源码)计算机毕业设计
    摘 要随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,个性化资讯推荐系统当然也不能排除在外。个性化资讯推荐系统是以实际运用为开发背景,运用软件工程原理和开发方法,采用springboot技术构建的一个管理系统。整......
  • c#语言,SQL server数据库;基于Web的社区人员管理系统的设计与实现36303(免费领源码)计算机
    目 录摘要1绪论1.1慨述1.2课题意义1.3B/S体系结构介绍1.4ASP.NET框架介绍2 社区人员管理系统分析2.1可行性分析2.2系统流程分析2.2.1数据增加流程2.2.2数据修改流程52.2.3数据删除流程52.3系统功能分析62.3.1功能性分析62.3.2非功能性......
  • Python多种接口请求方式示例
    发送JSON数据如果你需要发送JSON数据,可以使用json参数。这会自动设置Content-Type为application/json。importrequestsimportjsonurl='http://example.com/api/endpoint'data={"key":"value","another_key":"another_value"......