首页 > 编程语言 >monodepth学习4-训练讲解

monodepth学习4-训练讲解

时间:2022-08-20 15:23:02浏览次数:77  
标签:opt 训练 models train self num 讲解 monodepth

训练学习

monodepth2的训练过程由于存在多个训练模式和网络结构导致部分比较难以理解,这里我们结合网上的资料和自己对代码的理解进行简要地介绍,个人能力有限,对计算机视觉接触较少,如果有错误欢迎指正。

三种训练模式

monodepth2在readme中表示他们采用了三种训练方式,单目(Mono)、双目(Stereo)、联合(Mono+Stereo),提供的代码可以支持三种模式,不同模式中的训练指令可以查看experiments文件夹内文件,内部对多种模式的训练都提供了代码。
对于不同帧使用的是不同字母代替,-1表示前一帧,0表示当前帧,1使用后一帧,s表示双目另一侧的图片。
使用单目模式进行训练时,使用的数据是[-1,0,1]
使用双目训练时为[0,s]
联合训练时为[-1,0,1,s]

使用不同网络结构进行训练

除了三种训练模式,monodepth2还提供了不同的网络结构进行训练,这里主要是对于姿势网络进行了不同网络的测试,有三种分别是:separate_resnet,
posecnn,shared。这三种模式表示的是对于姿态网络的encode部分采用了不同的网络结构,separate_resnet模式下,使用的是不同的Resnet网络,shared模式下是和获取深度的网络采用同一个编码网络,posecnn表示共享编码器,且解码器采用的是posecnn网络。monodepth2推荐情况下使用separate_resnet模式。

初始化

def __init__(self, options):
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)
        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda
        self.num_scales = len(self.opt.scales)#这个是初始化是用于模型encode和decode缩放的倍数
        self.num_input_frames = len(self.opt.frame_ids)#这里表示的就是[-1,0,1]图片类型,不同模型不同
        self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames
        #表示模型的pose network 图片数量,在单目下是3,双目是2,联合是4

        assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"
        #表示输入的的图片是哪个帧数,当前,前一,后一,默认[0,-1,1],如果是改为stereo双面立体预测就要设置为0
        self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0])
        #使用pose net
        if self.opt.use_stereo:#双目训练
            self.opt.frame_ids.append("s")#双目的另一侧图片
      
        self.models["encoder"] = networks.ResnetEncoder(
            self.opt.num_layers, self.opt.weights_init == "pretrained")
        #深度网络编码器
        self.models["encoder"].to(self.device)
        self.parameters_to_train += list(self.models["encoder"].parameters())

        self.models["depth"] = networks.DepthDecoder(
            self.models["encoder"].num_ch_enc, self.opt.scales)
        self.models["depth"].to(self.device)#初始化深度网络解码器
        self.parameters_to_train += list(self.models["depth"].parameters())

        if self.use_pose_net:#初始化pose net,不太理解这里,我猜是测试不同的pose net 的差异,用于测试定义的pose net 是否有用
           ......#省略

        if self.opt.predictive_mask:
            assert self.opt.disable_automasking, \
                "When using predictive_mask, please disable automasking with --disable_automasking"

            # Our implementation of the predictive masking baseline has the the same arc
            # as our depth decoder. We predict a separate mask for each source frame.
            # 我们对预测屏蔽基线的实现具有相同的架构
            # 作为我们的深度解码器。我们为每个源帧预测一个单独的掩码。如果不使用会使用要给别人的掩码器,
            #这里应该对应的是第三点优化,auto-mask设置
            self.models["predictive_mask"] = networks.DepthDecoder(
                self.models["encoder"].num_ch_enc, self.opt.scales,
                num_output_channels=(len(self.opt.frame_ids) - 1))
            self.models["predictive_mask"].to(self.device)
            self.parameters_to_train += list(self.models["predictive_mask"].parameters())

        self.model_optimizer = optim.Adam(self.parameters_to_train, self.opt.learning_rate)#使用adam训练
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, self.opt.scheduler_step_size, 0.1)
        ......
        # data
        datasets_dict = {"kitti": datasets.KITTIRAWDataset,
                         "kitti_odom": datasets.KITTIOdomDataset}
                         #第一个是深度数据,第二个是训练图片路径
        self.dataset = datasets_dict[self.opt.dataset]

        fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")

        train_filenames = readlines(fpath.format("train"))#训练的数据列表
        val_filenames = readlines(fpath.format("val"))#深度数据
        img_ext = '.png' if self.opt.png else '.jpg'

        num_train_samples = len(train_filenames)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        train_dataset = self.dataset(
            self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
            #这里获取数据
        self.train_loader = DataLoader(
            train_dataset, self.opt.batch_size, True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
            #获取数据的迭代器
        val_dataset = self.dataset(
            self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
        self.val_loader = DataLoader(
            val_dataset, self.opt.batch_size, True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
        self.val_iter = iter(self.val_loader)

        self.writers = {}
        for mode in ["train", "val"]:
            self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))

        if not self.opt.no_ssim:
            self.ssim = SSIM()
            self.ssim.to(self.device)

        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.opt.scales:
            h = self.opt.height // (2 ** scale)
            w = self.opt.width // (2 ** scale)
            #这里修改高和宽就是为了多尺度进行损失函数技术,是优化的一种
            self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w)#这个使用讲使转化深度图
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)#将深度图转化二维图像
            self.project_3d[scale].to(self.device)

        self.depth_metric_names = [
            "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"]
        self.save_opts()

对于初始化没有特别多好介绍的主要就是加载了不同的网络模型,创建了数据对象,创建了几个之后用于自监督的对象,之后进行训练。

标签:opt,训练,models,train,self,num,讲解,monodepth
From: https://www.cnblogs.com/blackworld-sp/p/16607675.html

相关文章

  • "蔚来杯"2022牛客暑期多校训练营4
    A.TaskComputing给定\(n\)个任务,每个任务有两个权值\(w_i,p_i\),从中按任意顺序选出\(m\)个任务\((a_1,a_2,...,a_m)\),收益为\(\sum\limits_{i=1}^mw_{a_i}\prod\limits_{......
  • monodepth2-代码目录讲解
    代码目录讲解这里对个人对代码的理解进行讲解,个人由于设备不太行,没有去对模型进行复现,仅集合了网上内容对代码进行了解读,希望可以有帮助。目录结构asserts:这个主要是......
  • JQuery_遍历for循环&each方法$全局each&forof讲解
    遍历js的遍历方式for(初始化值;循环结束条件;步长)JQuery遍历方式JQuery对象.each(callback)$.each(object,[callback])for..of;<!DOCTYPEhtml><html><hea......
  • monodepth2学习-KITTI数据集内容
    KITTI数据集介绍monodepth2采用KITTI数据集进行训练,KITTI数据集主要是针对自动驾驶领域的图形处理技术,主要应用在评测立体图像(stereo)、光流(opticalflow)、3D物体检查等计......
  • monodepth2学习1-原理介绍
    monodepth2介绍monodepth2是在2019年CVPR会议上提出的一种三维重建算法,monodepth2是基于monodepth进行了改进,采用的是基于自监督的神经网络,提出了一下三点优化:一个最小......
  • 18js面向对象回顾及原型讲解
    面向对象回顾核心概念:万物皆对象(顶层对象Object)抽取名词作为属性抽取行为作为方法俩种构建对象的方式构造函数构建es6的形式classclassPerson{constructor(......
  • 1021 ObstacleCourse障碍训练课 优先队列+bfs+转弯
    链接:https://ac.nowcoder.com/acm/contest/26077/1021来源:牛客网题目描述考虑一个NxN(1<=N<=100)的有1个个方格组成的正方形牧场。......
  • 【限时领奖】消息队列 MNS 训练营重磅来袭,边学习充电,边领充电宝~
    阿里云消息队列MNS定位是RocketMQ轻量版,提供轻量模型、轻量HTTPRESTful协议,运维轻量、计费轻量,具备易集成等特点。为了帮助大家由浅入深的对阿里云消息队列MNS有......
  • 69用于预训练BERT的数据集
    点击查看代码importosimportrandomimporttorchfromd2limporttorchasd2l#@saved2l.DATA_HUB['wikitext-2']=('https://s3.amazonaws.com/research.m......
  • 69预训练BERT
    点击查看代码importtorchfromtorchimportnnfromd2limporttorchasd2lbatch_size,max_len=512,64train_iter,vocab=d2l.load_data_wiki(batch_size,......