# Clear the workspace
rm(list = ls())
setwd("C:\\Users\\Administrator\\Desktop\\machine learning\\LSTM")
library(magrittr) # provides the %>% pipe operator
library(keras)    # Keras interface (not actually used in this GBM script)
library(dplyr)    # filter()
library(caret)    # confusionMatrix()
set.seed(123)     # make the results reproducible
# Read in the data
otu <- read.table("genus_otu.txt", sep = "\t", header = TRUE, row.names = 1)
group <- read.table("group.txt", sep = "\t", header = TRUE)
# Merge the OTU table with the grouping table
otu_transposed <- t(otu) # transpose the OTU table: samples as rows, genera as columns
otu_combined <- merge(group, otu_transposed, by.x = "Sample", by.y = "row.names") # merge
# Convert Gene (genotype) and Time to factors (categorical variables)
otu_combined$Gene <- as.factor(otu_combined$Gene)
otu_combined$Time <- as.factor(otu_combined$Time)
# Manually select the test-set samples
test_samples <- c("B2W10_1", "B2W8_1", "B2W6_1", "B2W4_1", "M2W10_1", "M2W8_1", "M2W6_1", "M2W4_1")
# Additional candidate test samples, left commented out:
# "B3W10_1", "B3W8_1", "B3W6_1", "B3W4_1", "M3W10_1", "M3W8_1", "M3W6_1", "M3W4_1",
# Split into training and test sets
test_data <- otu_combined %>% filter(Sample %in% test_samples)
train_data <- otu_combined %>% filter(!Sample %in% test_samples)
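# A quick sanity check of the split (not part of the original script): confirm that
# every genotype and time point is represented on both sides before fitting anything.
table(train_data$Gene); table(test_data$Gene)
table(train_data$Time); table(test_data$Time)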
# Load the gbm package
library(gbm)
# Remove features with no variation in the training set
# (indices are computed on train_data without the Sample column, hence the +1 offset below)
no_variation_cols <- which(apply(train_data[, -1], 2, function(col) length(unique(col)) == 1))
train_data_filtered <- train_data[, -c(1, no_variation_cols + 1)]
test_data_filtered <- test_data[, -c(1, no_variation_cols + 1)]
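# As a cross-check, caret's nearZeroVar() (caret is already loaded) flags the same
# zero-variance columns plus near-constant ones. Shown only for inspection; the manual
# filtering above is what the rest of the script relies on.
nzv <- nearZeroVar(train_data[, -1], saveMetrics = TRUE)
head(nzv[nzv$zeroVar, ])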
# Train a gradient boosted trees (GBT) model to predict genotype
gbm_model_gene <- gbm(
  Gene ~ .,                     # predict genotype from all remaining features
  data = train_data_filtered,   # sample-name column already removed
  distribution = "multinomial", # multi-class classification
  n.trees = 1000,               # number of trees
  interaction.depth = 3,        # maximum tree depth
  shrinkage = 0.01,             # learning rate
  cv.folds = 5                  # 5-fold cross-validation
)
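# Because cv.folds = 5 was requested, gbm.perf() can estimate the optimal number of
# trees from the cross-validation error. This is an optional sketch; the predictions
# below follow the original script and use all gbm_model_gene$n.trees trees.
best_iter_gene <- gbm.perf(gbm_model_gene, method = "cv")
print(best_iter_gene)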
# Predict on the test set.
# For distribution = "multinomial", predict.gbm() returns class probabilities as a
# 3-D array (samples x classes x n.trees); drop the third dimension to get a matrix.
pred_gene <- predict(gbm_model_gene, newdata = test_data_filtered,
                     n.trees = gbm_model_gene$n.trees, type = "response")[, , 1]
write.table(pred_gene, file = "pred_gene.txt", sep = "\t", row.names = TRUE, col.names = NA)
# Convert the predicted probabilities into class labels
pred_gene_class <- apply(pred_gene, 1, function(x) colnames(pred_gene)[which.max(x)])
write.table(pred_gene_class, file = "pred_gene_class.txt", sep = "\t", row.names = TRUE, col.names = NA)
# Accuracy: align the predicted factor levels with the reference before confusionMatrix()
confusionMatrix(factor(pred_gene_class, levels = levels(test_data$Gene)), test_data$Gene)
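# A plain accuracy figure alongside the full confusionMatrix() output
# (sanity-check sketch, not part of the original script).
mean(pred_gene_class == as.character(test_data$Gene))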
# Feature (genus) importance for the genotype model
importance_gene <- summary(gbm_model_gene)
write.table(importance_gene, file = "results.txt", sep = "\t", row.names = TRUE, col.names = NA)
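# summary.gbm() returns a data frame with columns var and rel.inf; a bar chart of the
# most influential genera is often easier to read than the raw table. Assumes ggplot2
# is installed; the top-20 cutoff is arbitrary.
library(ggplot2)
top_gene <- head(importance_gene[order(-importance_gene$rel.inf), ], 20)
ggplot(top_gene, aes(x = reorder(var, rel.inf), y = rel.inf)) +
  geom_col() +
  coord_flip() +
  labs(x = "Genus", y = "Relative influence")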
# Train a GBT model to predict time point
gbm_model_time <- gbm(
  Time ~ .,                     # predict time point from all remaining features
  data = train_data_filtered,   # same filtered training set as for the genotype model
  distribution = "multinomial", # multi-class classification
  n.trees = 1000,               # number of trees
  interaction.depth = 3,        # maximum tree depth
  shrinkage = 0.01,             # learning rate
  cv.folds = 5                  # 5-fold cross-validation
)
# Inspect the model (relative influence of each genus)
summary(gbm_model_time)
# Predict on the test set; again drop the third dimension of the probability array
pred_time <- predict(gbm_model_time, newdata = test_data_filtered,
                     n.trees = gbm_model_time$n.trees, type = "response")[, , 1]
pred_time_class <- apply(pred_time, 1, function(x) colnames(pred_time)[which.max(x)])
confusionMatrix(factor(pred_time_class, levels = levels(test_data$Time)), test_data$Time)
# Feature importance analysis
importance_time <- summary(gbm_model_time)
print(importance_time)
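# For symmetry with the genotype model, the time-point predictions and importance
# table could be written out as well (file names here are illustrative, not from the
# original script).
write.table(pred_time, file = "pred_time.txt", sep = "\t", row.names = TRUE, col.names = NA)
write.table(pred_time_class, file = "pred_time_class.txt", sep = "\t", row.names = TRUE, col.names = NA)
write.table(importance_time, file = "importance_time.txt", sep = "\t", row.names = TRUE, col.names = NA)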