目录
1. 数据准备 data process
经过科研02部分的数据预处理,我们已经得到了以下内容:
- train-image & train-label --> train.npz, train.txt
- test-image & test-label --> test.npz, text.txt
- val-image & val-label --> val.npz, val.txt
2. 文件更名 files rename
已经确认了数据内容的正确。
在复现代码时,尽量不更改代码中的文件名字,因此接下来还需要将这些数据更改为原始代码中使用的文件名。
2.1. 数据更名 npz rename
05_npz_files文件夹更名为Synapse,该文件夹放在data文件夹中。
其下有两个文件夹:
-
train_npz文件夹名不做变更。
-
test_npz和val_npz两个文件夹中的内容合并,并更名为test_vol_h5。
不用担心test_npz和val_npz混淆在一起,代码会通过txt文档来筛选。
2.2. 文档更名 txt rename
06_npzFiles_txt文件夹更名为lists_Synapse。
其下有三个文件:
-
train.txt文件名不做更改。
-
test.txt和val.txt中的内容合并在一起,更名为test_vol.txt。
3. 代码修改 code change
大多还是依据TransUnet官方代码训练自己数据集中的内容进行的修改。
3.1. 目录调整 contents
目录安排并未按照上述csdn的链接,data、model、predictions和TransUNet-main文件夹是同级别的。
TransUNet-main下包含的文件夹包括:networks、datasets、test_log和lists,文件包括trainer.py,test.py,utils.py和train.py。
可以按照上述内容做一下核查。
3.2. 数据读取 code1
- TransUNet-main-> datasets -> dataset_synapse.py
按照TransUnet官方代码训练自己数据集中的内容进行修改。
dataset_synapse.py文件中的Synapse_dataset类中,修改__getitem__函数如下:
def __getitem__(self, idx):
if self.split == "train":
slice_name = self.sample_list[idx].strip('\n')
data_path = self.data_dir+"/"+slice_name+'.npz'
data = np.load(data_path)
image, label = data['image'], data['label']
else:
slice_name = self.sample_list[idx].strip('\n')
data_path = self.data_dir+"/"+slice_name+'.npz'
data = np.load(data_path)
image, label = data['image'], data['label']
image = torch.from_numpy(image.astype(np.float32))
image = image.permute(2,0,1)
label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
sample['case_name'] = self.sample_list[idx].strip('\n')
return sample
dataset_synapse.py文件中的RandomGenerator类,修改__call__函数如下:
def __call__(self, sample):
image, label = sample['image'], sample['label']
if random.random() > 0.5:
image, label = random_rot_flip(image, label)
elif random.random() > 0.5:
image, label = random_rotate(image, label)
x, y,_ = image.shape
if x != self.output_size[0] or y != self.output_size[1]:
image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3) # why not 3?
label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
image = torch.from_numpy(image.astype(np.float32))
image = image.permute(2,0,1)
label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label.long()}
return sample
3.2. 训练参数 parameter set
主要是train.py文件中的参数。
而为尽可能的保证不修改源代码中的内容,博主尽可能的保证文件夹的位置与github中的设置一致。
TransUnet官方代码训练自己数据集中提到的../内容修改为./的部分都未进行修改。
3.2.1. 目标类别 num classes
train.py文件。
下面代码中的default依据label中的物体类别来定义。
约在第18行。
-
建筑物识别:类别1是背景,类别2是建筑物,default=2
-
道路识别:类别1是背景,类别2是建筑物,default=2
-
土地利用覆盖分类:类别1是大豆,类别2是小麦,类别3是水稻,类别4是其他,default=4
-
···
parser.add_argument('--num_classes', type=int, default=2, help='output channel of network')
3.2.2. 运行轮次 max epochs
train.py文件。
下面代码中的default设定为想要运行多少个epoch来决定。
约在第22行。
parser.add_argument('--max_epochs', type=int, default=150, help='maximum epoch number to train')
3.2.3. 批次传入 batch size
train.py文件。
下面代码中的default设定一个iteration传入多少个image和label。
- 对于512 X 512大小的image和label,16GB的显存,可以设定batch_size为4,如果显存为12GB的话,应该可以设定为2,如果显存为8GB及以下,建议设定为1吧,不然电脑会卡。
约在第24行。
parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
3.2.4. 图片尺寸 image size
train.py文件。
因处理的图片是512大小的,故而设定default=512。
约在第31行。
parser.add_argument('--img_size', type=int, default=512, help='input patch size of network input')
3.2.5. 数据设置 dataset_config
train.py文件。
一些不是参数的部分,但需要修改。
因为数据集的设置,所以只需要更改下面代码中的num_classes为2。
这里其实只是修改了参数部分的默认值?
约在第60行。
dataset_config = {
'Synapse': {
'root_path': '../data/Synapse/train_npz',
'list_dir': './lists/lists_Synapse',
'num_classes': 2,
},
}
3.2.6. 保存名称 save name
train.py文件。
建议每次都对下面这些内容进行修改,确保生成的模型文件是在一个新的文件夹中,而不会覆盖前一次的模型训练结果。
如果不修改,并且也没修改max_epochs,那么会无情的覆盖上一次的模型结果,导致无法在必要时调用不同阶段训练的模型文件对结果进行测试。
建议将两行的TU都设定成任务名_日期,如RoadExtract_231013。
-
道路提取任务:RoadExtract
-
日期:23年10月13日
约在第67行。
# 修改前
args.exp = 'TU_' + dataset_name + str(args.img_size)
snapshot_path = "../model/{}/{}".format(args.exp, 'TU')
# 修改后
args.exp = 'RoadExtract_231013_' + dataset_name + str(args.img_size)
snapshot_path = "../model/{}/{}".format(args.exp, 'RoadExtract_231013')
3.2.6. 重要修改 important set
trainer.py文件。
如果不修改,会出现电脑自动重启or程序莫名中断等问题。
约在第33行。
trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, worker_init_fn=worker_init_fn)
4. 实际实验 real experiment
train-image & train-label : 8660
16GB显存,batch-size设置为4,epochs设置为150。
预计运行时间:1week!!!!!
这也太慢了!早知道设置epochs设置小一点了。
5. 改进想法 think
5.1. 权重保存 parameter save
不能保存每一次训练的权重。
这确实减少了存储空间的损耗,但是导致不训练够100个epoch或训练完所有的epoch,就没有权重文件被保存下来。
如果在即将训练完时电脑断电,那么可能一周的结果就徒然无功了。
不说保存每一个epoch,起码应当考虑每5个epochs之类的保存一下权重。
5.2. 权重名称 parameter name
权重的名字保存下来是无意义的epoch_99.pth或epoch_149.pth,或许可以将权重的名字更改为含有精度评价指标的名字,如:epoch_49_miou-72.33_oa-88.75_···.pth。
后续有机会将针对这些内容做一些修改。
标签:03,image,label,train,复现,TransUnet,data,self,size From: https://www.cnblogs.com/If-I-Were-A-Bird/p/17759904.html