首页 > 其他分享 >【科研03】【代码复现】TransUnet道路提取

【科研03】【代码复现】TransUnet道路提取

时间:2023-10-15 09:04:16浏览次数:36  
标签:03 image label train 复现 TransUnet data self size

目录

1. 数据准备 data process

  经过科研02部分的数据预处理,我们已经得到了以下内容:

  • train-image & train-label --> train.npz, train.txt
  • test-image & test-label --> test.npz, text.txt
  • val-image & val-label --> val.npz, val.txt

2. 文件更名 files rename

  已经确认了数据内容的正确。

  在复现代码时,尽量不更改代码中的文件名字,因此接下来还需要将这些数据更改为原始代码中使用的文件名。

2.1. 数据更名 npz rename

  05_npz_files文件夹更名为Synapse,该文件夹放在data文件夹中。

  其下有两个文件夹:

  • train_npz文件夹名不做变更。

  • test_npzval_npz两个文件夹中的内容合并,并更名为test_vol_h5

  不用担心test_npz和val_npz混淆在一起,代码会通过txt文档来筛选。

2.2. 文档更名 txt rename

  06_npzFiles_txt文件夹更名为lists_Synapse

  其下有三个文件:

  • train.txt文件名不做更改。

  • test.txtval.txt中的内容合并在一起,更名为test_vol.txt

3. 代码修改 code change

  大多还是依据TransUnet官方代码训练自己数据集中的内容进行的修改。

3.1. 目录调整 contents

  目录安排并未按照上述csdn的链接,data、model、predictions和TransUNet-main文件夹是同级别的。

  TransUNet-main下包含的文件夹包括:networks、datasets、test_log和lists,文件包括trainer.py,test.py,utils.py和train.py。

  可以按照上述内容做一下核查。

3.2. 数据读取 code1

  • TransUNet-main-> datasets -> dataset_synapse.py

  按照TransUnet官方代码训练自己数据集中的内容进行修改。

  dataset_synapse.py文件中的Synapse_dataset类中,修改__getitem__函数如下:

 def __getitem__(self, idx):
        if self.split == "train":
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir+"/"+slice_name+'.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
        else:
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir+"/"+slice_name+'.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
            image = torch.from_numpy(image.astype(np.float32))
            image = image.permute(2,0,1)
            label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = self.sample_list[idx].strip('\n')
        return sample

  dataset_synapse.py文件中的RandomGenerator类,修改__call__函数如下:

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y,_ = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3)  # why not 3?
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        image = torch.from_numpy(image.astype(np.float32))
        image = image.permute(2,0,1)
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}
        return sample

3.2. 训练参数 parameter set

  主要是train.py文件中的参数。

  而为尽可能的保证不修改源代码中的内容,博主尽可能的保证文件夹的位置与github中的设置一致。

  TransUnet官方代码训练自己数据集中提到的../内容修改为./的部分都未进行修改。

3.2.1. 目标类别 num classes

  train.py文件。

  下面代码中的default依据label中的物体类别来定义。

  约在第18行。

  • 建筑物识别:类别1是背景,类别2是建筑物,default=2

  • 道路识别:类别1是背景,类别2是建筑物,default=2

  • 土地利用覆盖分类:类别1是大豆,类别2是小麦,类别3是水稻,类别4是其他,default=4

  • ···

parser.add_argument('--num_classes', type=int, default=2, help='output channel of network')

3.2.2. 运行轮次 max epochs

  train.py文件。

  下面代码中的default设定为想要运行多少个epoch来决定。

  约在第22行。

parser.add_argument('--max_epochs', type=int, default=150, help='maximum epoch number to train')

3.2.3. 批次传入 batch size

  train.py文件。

  下面代码中的default设定一个iteration传入多少个image和label。

  • 对于512 X 512大小的image和label,16GB的显存,可以设定batch_size为4,如果显存为12GB的话,应该可以设定为2,如果显存为8GB及以下,建议设定为1吧,不然电脑会卡。

  约在第24行。

parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')

3.2.4. 图片尺寸 image size

  train.py文件。

  因处理的图片是512大小的,故而设定default=512。

  约在第31行。

parser.add_argument('--img_size', type=int, default=512, help='input patch size of network input')

3.2.5. 数据设置 dataset_config

  train.py文件。

  一些不是参数的部分,但需要修改。

  因为数据集的设置,所以只需要更改下面代码中的num_classes为2。

这里其实只是修改了参数部分的默认值?

  约在第60行。

    dataset_config = {
        'Synapse': {
            'root_path': '../data/Synapse/train_npz',
            'list_dir': './lists/lists_Synapse',
            'num_classes': 2,
        },
    }

3.2.6. 保存名称 save name

  train.py文件。

建议每次都对下面这些内容进行修改,确保生成的模型文件是在一个新的文件夹中,而不会覆盖前一次的模型训练结果。

  如果不修改,并且也没修改max_epochs,那么会无情的覆盖上一次的模型结果,导致无法在必要时调用不同阶段训练的模型文件对结果进行测试。

  建议将两行的TU都设定成任务名_日期,如RoadExtract_231013

  • 道路提取任务:RoadExtract

  • 日期:23年10月13日

  约在第67行。

    # 修改前
    args.exp = 'TU_' + dataset_name + str(args.img_size)
    snapshot_path = "../model/{}/{}".format(args.exp, 'TU')
    # 修改后
    args.exp = 'RoadExtract_231013_' + dataset_name + str(args.img_size)
    snapshot_path = "../model/{}/{}".format(args.exp, 'RoadExtract_231013')

3.2.6. 重要修改 important set

  trainer.py文件。

  如果不修改,会出现电脑自动重启or程序莫名中断等问题。

  约在第33行。

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, worker_init_fn=worker_init_fn)

4. 实际实验 real experiment

  train-image & train-label : 8660

  16GB显存,batch-size设置为4,epochs设置为150。

  预计运行时间:1week!!!!!

  这也太慢了!早知道设置epochs设置小一点了。

5. 改进想法 think

5.1. 权重保存 parameter save

  不能保存每一次训练的权重。

  这确实减少了存储空间的损耗,但是导致不训练够100个epoch或训练完所有的epoch,就没有权重文件被保存下来。

  如果在即将训练完时电脑断电,那么可能一周的结果就徒然无功了。

  不说保存每一个epoch,起码应当考虑每5个epochs之类的保存一下权重。

5.2. 权重名称 parameter name

  权重的名字保存下来是无意义的epoch_99.pth或epoch_149.pth,或许可以将权重的名字更改为含有精度评价指标的名字,如:epoch_49_miou-72.33_oa-88.75_···.pth。


  后续有机会将针对这些内容做一些修改。

标签:03,image,label,train,复现,TransUnet,data,self,size
From: https://www.cnblogs.com/If-I-Were-A-Bird/p/17759904.html

相关文章

  • 【科研02】【代码复现】【代码分享】TransUnet-RoadExtract 道路提取【数据预处理-ras
    目录1.数据处理dataprocess1.1.类型转换RastertoPng1.2.边缘填充Resize1.2.1.填充Resizeimage1.2.1.填充Resizelabel1.3.批量裁剪Clip1.4.波段缩减3bandsto1band1.5.筛选图像Choose1.6.转换格式Transformtonpz1.7.读取列表ReadFilesToList1.数......
  • 从链接器的角度详细分析g++报错: (.text+0x24): undefined reference to `main'
    /usr/bin/ld:/usr/lib/gcc/x86_64-linux-gnu/9/../../../x86_64-linux-gnu/Scrt1.o:infunction`_start':(.text+0x24):undefinedreferenceto`main'collect2:error:ldreturned1exitstatus  在使用g++编译链接两个C++源文件main.cpp以及VecAdd.cpp时出现了以上......
  • Flutter错误type 'Null' is not a subtype of type 'Handler'
    报错type'Null'isnotasubtypeoftype'Handler'原因分析在使用Fluro路由库时,出现"type'Null'isnotasubtypeoftype'Handler'"错误通常表示你尝试将一个空(null)值分配给Fluro的Handler对象或调用了未初始化的路由处理程序。解决方法这个错误通常发生在以下......
  • 2023 巅峰极客 m1_read 详细复现
    定位逻辑本题给出了bin文件,即out.bin,故可以猜测其内部包含了加密结果或者密钥等m1_read文件打开后,函数数量不多,并且静态分析WinMain不可行于是翻找函数,可以找到形如AES的函数(sub_4BF0)利用Findcrypt也出现了AES的特征码,于是假定是AES,并且没有魔改函数接近结尾部分可以看出这......
  • Codeforces Round 903 (Div. 3) F. Minimum Maximum Distance(图论)
    CodeforcesRound903(Div.3)F.MinimumMaximumDistance思路对标记点更新fg,从0开始进行bfs,更新d1为所有点到0的距离获得到0最远的标记点L,从L开始bfs,更新d2为所有点到L的距离获得距离L最远的标记点R,从R开始bfs,更新d3为所有点到R的距离遍历所有点,这个点与标记点的最大距......
  • 【科研01】【代码复现】TransUnet-文件目录安排
    目录1.信息TransUnet1.1.时间opentime1.2.链接Linkgithub1.3.应用Use2.自用TransUnet2.1.目录Tree2.2.修改Change1.信息TransUnet1.1.时间opentime20211.2.链接Linkgithubhttps://github.com/Beckschen/TransUNet1.3.应用Use  本是应用于......
  • operator Demo03
    packageoperator;publicclassDemo03{publicstaticvoidmain(String[]args){//关系运算符返回的结果:正确,错误布尔值//ifinta=10;intb=20;intc=21;//取余,模运算System.out.println(c%a);//......
  • 算法讲解0304
    1、打印二进制voidprint(intnum){ for(inti=31;i>=0;i--) if((num&(1<<i))==0) cin>>0; else cin>>1;}2、选择排序voidselectionSort(intarry[]){ intn=sizeof(arry)/sizeof(*a); if(n<2)return; for(inti=0......
  • [AGC033C] Removing Coins题解
    思路可以看出,每次对一个点\(u\)操作一次,就相当于删除以\(u\)为根的所有叶节点。当然我们还是没有什么思路,我们可以想简单一点:在一条链上的情况。如果\(u\)是链的端点:以\(u\)为根节点的叶节点只有一个,所以链的长度减一。如果\(u\)不是链的端点:以\(u\)为根节点......
  • 动手动脑03
    1. 实际操作了一下,发现确实super基类构造法只能在子类构造法前面。放在后面会报错。 2.如果父类的构造方法调用了子类的方法或使用了子类的属性,那么在父类构造方法执行时,子类可能还没有被完全初始化,这将导致运行时错误。因此,不能反过来调用父类的构造方法。必须在子类的......