
PyTorch Image Classification Training Framework



Most of the code for training an image classifier with PyTorch is reusable across projects. I have extracted the training code I wrote for the Kaggle competition Paddy Doctor into a framework that can be reused for future image classification tasks.

The code is adapted from ahangchen/torch_base; the full source is available at Base-Pytorch-Trainer.


1. Dependencies

pip install tensorboard

conda install pytorch torchvision torchaudio cudatoolkit=10.2 pyyaml -c pytorch

2. File Organization

.
├── checkpoints
├── configs
├── dataset
├── submissions
├── do_predict.sh
├── do_split_dataset.sh
├── do_tensorboard.sh
├── do_train.sh
├── LICENSE
├── README.md
└── src
    └── main
        ├── data
        ├── engine
        ├── model
        ├── options
        ├── predict.py
        ├── split_dataset.py
        └── train.py
  • checkpoints: stores checkpoint files

  • configs: stores the default parameter configuration files

  • dataset: stores the dataset files

  • submissions: stores the final submission files

  • do_predict.sh: prediction script

  • do_split_dataset.sh: dataset-splitting script

  • do_tensorboard.sh: script that launches TensorBoard

  • do_train.sh: training script

  • src: source code

    • main: main source code

      • data: data preprocessing code, including data augmentation and data loader construction

      • engine: code for training, prediction, and dataset splitting

      • model: neural network model definitions

      • options: argument-parsing code

      • predict.py: prediction entry point

      • split_dataset.py: dataset-splitting entry point

      • train.py: training entry point

3. Program Entry Points

Shell scripts serve as the program entry points, which makes it convenient to set parameters:

  • do_predict.sh: prediction script

  • do_split_dataset.sh: dataset-splitting script

  • do_tensorboard.sh: script that launches TensorBoard

  • do_train.sh: training script

Training script:

if [ ! -d "checkpoints" ];then
  mkdir checkpoints;
fi
cd ./src/main/ && \
python ./train.py \
--config_file_path ../../configs/train_config.yaml \
--epochs 500 \
--batch_size 8 \
--dataset_dir ../../dataset/paddy-disease-classification/ \
--model_type base_model \
| tee ../../checkpoints/output.txt

Prediction script:

cd ./src/main/ && \
python ./predict.py \
--config_file_path ../../configs/eval_config.yaml \
--model_type efficient_model \
--dataset_dir ../../dataset/ \
--submission_file_path ../../submissions/submission.csv

Dataset-splitting script:

cd ./src/main/ && \
python ./split_dataset.py \
--config_file_path ../../configs/split_config.yaml \
--dataset_dir ../../dataset/paddy-disease-classification/

4. Argument Handling

By subclassing Python's argparse.ArgumentParser, default parameters are loaded from a YAML file and can still be overridden on the command line.

import argparse

import yaml


class ConfigArgumentParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        self.config_parser = argparse.ArgumentParser()
        self.config_parser.add_argument("-c", "--config_file_path", default=None, metavar="FILE",
                                        help="where to load YAML configuration")
        self.option_names = []
        super(ConfigArgumentParser, self).__init__(*args, **kwargs)

    def add_override_argument(self, *args, **kwargs):
        arg = super().add_argument(*args, **kwargs)
        self.option_names.append(arg.dest)
        return arg

    def parse_args(self, args=None):
        res, remaining_argv = self.config_parser.parse_known_args(args)
        if res.config_file_path is not None:
            with open(res.config_file_path, "r") as f:
                config_vars = yaml.safe_load(f)
            for key in config_vars:
                if key not in self.option_names:
                    self.error(f"unexpected configuration entry: {key}")
            self.set_defaults(**config_vars)

        return super().parse_args(remaining_argv)
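
For example, assuming a hypothetical demo_config.yaml that contains only the line "lr: 1e-3", values from the YAML become argparse defaults and explicit command-line flags still win:

# Hypothetical usage sketch; demo_config.yaml is not part of the repository.
parser = ConfigArgumentParser()
parser.add_override_argument('--lr', type=float, help='learning rate')

args = parser.parse_args(['-c', 'demo_config.yaml', '--lr', '0.01'])
print(args.lr)  # 0.01 -- the command line overrides the YAML value

args = parser.parse_args(['-c', 'demo_config.yaml'])
print(args.lr)  # 0.001 -- falls back to the YAML default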

The training arguments are set up as follows:

def prepare_train_args():
    train_parser = ConfigArgumentParser()
    train_parser.add_override_argument('--seed', type=int,
                                       help='a random seed')
    train_parser.add_override_argument('--gpus', nargs='+', type=int,
                                       help='numbers of GPU')
    train_parser.add_override_argument('--epochs', type=int,
                                       help='total epochs')
    train_parser.add_override_argument('--batch_size', type=int,
                                       help='batch size')
    train_parser.add_override_argument('--lr', type=float,
                                       help='learning rate')
    train_parser.add_override_argument('--momentum', type=float,
                                       help='momentum for sgd, alpha parameter for adam')
    train_parser.add_override_argument('--beta', default=0.999, type=float,
                                       help='beta parameters for adam')
    train_parser.add_override_argument('--weight_decay', '--wd', type=float,
                                       help='weight decay')
    train_parser.add_override_argument('--save_prefix', type=str,
                                       help='some comment for model or test result dir')
    train_parser.add_override_argument('--model_type', type=str,
                                       help='choose a model type, defined in the model folder')
    train_parser.add_override_argument('--loss_type', type=str,
                                       help='choose a loss function, defined in the loss folder')
    train_parser.add_override_argument('--acc_type', type=str,
                                       help='choose an accuracy function, defined in the metrics folder')
    train_parser.add_override_argument('--is_load_strict', action='store_false',
                                       help='pass this flag to disable strict loading and allow partially matching state dicts')
    train_parser.add_override_argument('--is_load_pretrained_weight', action='store_true',
                                       help='True means try to load pretrained weights')
    train_parser.add_override_argument('--pretrained_weights_path', type=str,
                                       help='pretrained weights path')
    train_parser.add_override_argument('--is_resuming_training', action='store_true',
                                       help='True means try to resume previous train')
    train_parser.add_override_argument('--checkpoint_path', type=str,
                                       help='checkpoints path')
    train_parser.add_override_argument('--dataset_dir', type=str,
                                       help='dataset directory')
    train_parser.add_override_argument('--checkpoints_dir', type=str,
                                       help='checkpoints directory')
    args = train_parser.parse_args()
    get_train_model_dir(args)
    save_args(args, args.checkpoints_dir)
    return args
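
get_train_model_dir() and save_args() are not shown above. Below are hedged sketches of what they might look like, given that the arguments are saved to args.txt (see section 9) and checkpoints_dir is filled in at runtime; the versions in the repository may differ:

import os
import time


def get_train_model_dir(args):
    # Create a per-run checkpoint directory, e.g. checkpoints/base_model_test_20230107_132247
    run_name = f'{args.model_type}_{args.save_prefix}_{time.strftime("%Y%m%d_%H%M%S")}'
    args.checkpoints_dir = os.path.join('../../checkpoints', run_name)
    os.makedirs(args.checkpoints_dir, exist_ok=True)


def save_args(args, save_dir):
    # Persist the parsed arguments to args.txt for reproducibility
    with open(os.path.join(save_dir, 'args.txt'), 'w') as f:
        for key, value in sorted(vars(args).items()):
            f.write(f'{key}: {value}\n')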

Default training YAML configuration file (checkpoints_dir is left empty here and filled in at runtime by get_train_model_dir()):

seed: 42
gpus: [0]
epochs: 100
batch_size: 128
lr: 1e-3
momentum: 0.9
beta: 0.999
weight_decay: 0
save_prefix: "test"
model_type: "base_model"
loss_type: "focal_loss"
acc_type: "classification_acc"
is_load_strict: true
is_load_pretrained_weight: false
pretrained_weights_path: ''
is_resuming_training: false
checkpoint_path: ''
dataset_dir: ../../dataset/
checkpoints_dir:

The prediction arguments are set up as follows:

def prepare_eval_args():
    eval_parser = ConfigArgumentParser()
    eval_parser.add_override_argument('--seed', type=int,
                                      help='a random seed')
    eval_parser.add_override_argument('--gpus', nargs='+', type=int,
                                      help='numbers of GPU')
    eval_parser.add_override_argument('--model_type', type=str,
                                      help='used in model_interface.py')
    eval_parser.add_override_argument('--weights_path', type=str,
                                      help='weights path')
    eval_parser.add_override_argument('--dataset_dir', type=str,
                                      help='dataset directory')
    eval_parser.add_override_argument('--submission_file_path', type=str,
                                      help='submission.csv path')
    args = eval_parser.parse_args()
    return args

Default prediction YAML configuration file:

seed: 42
gpus: [0]
model_type: "base_model"
weights_path:
dataset_dir: ../../dataset/
submission_file_path: ../../submissions/submission.csv

The dataset-splitting arguments are set up as follows:

def prepare_split_dataset_args():
    split_parser = ConfigArgumentParser()
    split_parser.add_override_argument('--seed', type=int,
                                       help='a random seed')
    split_parser.add_override_argument('--valid_ratio', type=float,
                                       help='valid ratio')
    split_parser.add_override_argument('--dataset_dir', type=str,
                                       help='dataset directory')
    args = split_parser.parse_args()
    save_args(args, args.dataset_dir)
    return args

Default dataset-splitting YAML configuration file:

seed: 42
valid_ratio: 0.2
dataset_dir: ../../dataset/

5. Dataset Splitting

For image classification, PyTorch provides torchvision.datasets.ImageFolder for loading datasets. To use this API, the dataset must be organized in the following directory structure:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

This project uses that API, so the dataset-splitting function in Splitter.py should be adapted to your own dataset so that it produces the directory structure above; a minimal sketch follows.
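
A minimal sketch of such a split function, assuming a Kaggle-style source layout of train_images/<class>/<image> (as in Paddy Doctor); the actual Splitter.py may differ:

import os
import random
import shutil


def split_dataset(dataset_dir, valid_ratio=0.2, seed=42):
    # Copy images from train_images/<class>/ into train_valid_test/{train,valid}/<class>/
    random.seed(seed)
    src_root = os.path.join(dataset_dir, 'train_images')
    for class_name in os.listdir(src_root):
        images = os.listdir(os.path.join(src_root, class_name))
        random.shuffle(images)
        n_valid = int(len(images) * valid_ratio)
        for i, image_name in enumerate(images):
            split = 'valid' if i < n_valid else 'train'
            dst_dir = os.path.join(dataset_dir, 'train_valid_test', split, class_name)
            os.makedirs(dst_dir, exist_ok=True)
            shutil.copy(os.path.join(src_root, class_name, image_name),
                        os.path.join(dst_dir, image_name))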

6. Data Loading

Data loading uses PyTorch's torch.utils.data.DataLoader.

Training data loader:

import os

import torchvision
from torch.utils.data import DataLoader


def select_train_loader(args):
    train_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_dir, 'train_valid_test', "train"),
                                                     transform=transform_train)
    print(train_dataset.class_to_idx)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True,
                              drop_last=False)
    return train_loader

Evaluation data loader:

def select_eval_loader(args):
    eval_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_dir, 'train_valid_test', "valid"),
                                                    transform=transform_eval)
    val_loader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=False)
    return val_loader
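
The transform_train and transform_eval objects referenced above are not shown; a typical definition might look like the following (ImageNet normalization constants assumed; the transforms in the repository may differ):

import torchvision.transforms as transforms

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transform_eval = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])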

7. Model Definition

Create a new model folder under src/main/model and implement the model there. Then add a corresponding entry to the dictionary in the select_model() function of model_interface.py.

def select_model(args):
    type2model = {
        'base_model': base_model(),
        'better_model': better_model(),
        'efficient_model': efficient_model(),
    }
    model = type2model[args.model_type]
    return model
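
As a hypothetical example of adding a new model, you could wrap a torchvision backbone (resnet_model here is an illustrative name, not part of the original code):

import torch.nn as nn
import torchvision


def resnet_model(num_classes=10):
    # Fine-tunable ResNet-18 with the classification head replaced
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

Note that the dictionary above instantiates every registered model eagerly; mapping names to constructors instead ('base_model': base_model) and calling type2model[args.model_type]() would build only the model that is actually selected.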

8. Metrics

Define loss functions under src/main/engine/metrics/loss, then add a corresponding entry to the dictionary in the select_loss() function of metrics_interface.py.

def select_loss(args):
    type2lossFunction = {
        'focal_loss': FocalLoss(num_class=10),
    }
    loss_function = type2lossFunction[args.loss_type]
    return loss_function
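
The FocalLoss referenced above lives in the loss folder; below is a hedged sketch of a standard multi-class focal loss (the repository's implementation may differ):

import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    # Down-weights well-classified examples by a (1 - p)^gamma factor
    def __init__(self, num_class, gamma=2.0):
        super().__init__()
        self.num_class = num_class  # kept for interface compatibility
        self.gamma = gamma

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)
        # Log-probability of the true class for each sample
        true_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        true_probs = true_log_probs.exp()
        loss = -((1 - true_probs) ** self.gamma) * true_log_probs
        return loss.mean()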

Similarly, the classification accuracy metric can be redefined; a possible implementation is sketched below.
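
A hedged sketch of what the classification_acc function might look like (the version in the metrics folder may differ):

import torch


def classification_acc(output_batch: torch.Tensor, label_batch: torch.Tensor) -> float:
    # Fraction of samples whose highest-scoring class matches the label
    preds = output_batch.argmax(dim=1)
    return (preds == label_batch).float().mean().item()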

9. Training

Run the do_train.sh script directly. The command-line output is also saved to output.txt in the checkpoints folder, which makes it easy to track down errors. The training arguments are saved to args.txt in the checkpoint directory as well.

The trainer class:

class Trainer(object):
    def __init__(self, args, model, train_loader, val_loader):
        torch.manual_seed(args.seed)
        self.__args = args
        self.__logger = Logger(args)
        self.__loss_function = select_loss(args)
        self.__acc_function = select_acc(args)
        self.__train_loader = train_loader
        self.__val_loader = val_loader
        self.__start_epoch = 0

        train_status = 'Normal'
        train_status_logs = []

        # loading model
        self.__model = model
        if args.is_load_pretrained_weight:
            train_status = 'Continuance'
            self.__model.load_state_dict(torch.load(args.pretrained_weights_path), strict=args.is_load_strict)
            train_status_logs.append('Log   Output: Loaded pretrained weights successfully')

        if args.is_resuming_training:
            train_status = 'Restoration'
            checkpoint = torch.load(args.checkpoint_path)
            self.__start_epoch = checkpoint['epoch'] + 1
            self.__model.load_state_dict(checkpoint['model_state_dict'], strict=args.is_load_strict)
            train_status_logs.append('Log   Output: Resumed previous model state successfully')

        self.__device = torch.device(f"cuda:{args.gpus[0]}")
        if len(args.gpus) == 1:
            gpu_status = 'Single-GPU'
            self.__model.to(self.__device)
        else:
            gpu_status = 'Multi-GPU'
            # DataParallel expects the model to start on the first device in device_ids
            self.__model.to(self.__device)
            self.__model = torch.nn.DataParallel(self.__model, device_ids=args.gpus, output_device=args.gpus[0])

        # initialize the optimizer
        self.__optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.__model.parameters()),
                                            self.__args.lr,
                                            betas=(self.__args.momentum, self.__args.beta),
                                            weight_decay=self.__args.weight_decay)
        if args.is_resuming_training:
            checkpoint = torch.load(args.checkpoint_path)
            self.__optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            train_status_logs.append('Log   Output: Resumed previous optimizer state successfully')

        # print status
        print('****************************************************************************************************')
        print('Model:')
        print(self.__model)
        print('****************************************************************************************************')
        print('Params To Learn:')
        for name, param in self.__model.named_parameters():
            if param.requires_grad:
                print('\t', name)
        print('****************************************************************************************************')
        print('Train Status: ' + train_status)
        print('GPU   Status: ' + gpu_status)
        for train_status_log in train_status_logs:
            print(train_status_log)
        print('****************************************************************************************************')

    def train(self):
        for epoch in range(self.__start_epoch, self.__args.epochs):
            # train for one epoch
            since = time.time()
            self.__train_per_epoch()
            self.__val_per_epoch()
            self.__logger.save_curves(epoch)
            self.__logger.save_checkpoint(epoch, self.__model, self.__optimizer)
            self.__logger.print_logs(epoch, time.time() - since)
            self.__logger.clear_scalar_cache()

    def __train_per_epoch(self):
        # switch to train mode
        self.__model.train()

        for i, data_batch in enumerate(self.__train_loader):
            input_batch, output_batch, label_batch = self.__step(data_batch)

            # compute loss and acc
            loss, metrics = self.__compute_metrics(output_batch, label_batch, is_train=True)

            # compute gradient and do Adam step
            self.__optimizer.zero_grad()
            loss.backward()
            self.__optimizer.step()

            # logger record
            for key in metrics.keys():
                self.__logger.record_scalar(key, metrics[key])

    def __val_per_epoch(self):
        # switch to eval mode
        self.__model.eval()

        with torch.no_grad():
            for i, data_batch in enumerate(self.__val_loader):
                input_batch, output_batch, label_batch = self.__step(data_batch)

                # compute loss and acc
                loss, metrics = self.__compute_metrics(output_batch, label_batch, is_train=False)

                for key in metrics.keys():
                    self.__logger.record_scalar(key, metrics[key])

    def __step(self, data_batch):
        input_batch, label_batch = data_batch
        # move the batch to the primary device
        # (torch.autograd.Variable is deprecated since PyTorch 0.4, so plain tensors are used)
        input_batch = input_batch.to(self.__device)
        label_batch = label_batch.to(self.__device)

        # compute output
        output_batch = self.__model(input_batch)
        return input_batch, output_batch, label_batch

    def __compute_metrics(self, output_batch, label_batch, is_train):
        # you can call functions in metrics_interface.py
        loss = self.__calculate_loss(output_batch, label_batch)
        acc = self.__evaluate_accuracy(output_batch, label_batch)
        prefix = 'train/' if is_train else 'val/'
        metrics = {
            prefix + 'loss': loss.item(),
            prefix + 'accuracy': acc,
        }
        return loss, metrics

    def __calculate_loss(self, output_batch: torch.Tensor, label_batch: torch.Tensor) -> torch.Tensor:
        loss = self.__loss_function(output_batch, label_batch)
        return loss

    def __evaluate_accuracy(self, output_batch: torch.Tensor, label_batch: torch.Tensor) -> float:
        acc = self.__acc_function(output_batch, label_batch)
        return acc

    @staticmethod
    def __gen_imgs_to_write(img, is_train):
        # override this method according to your visualization
        prefix = 'train/' if is_train else 'val/'
        return {
            prefix + 'img': img[0],
        }
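
For reference, train.py presumably wires these pieces together along the following lines (a sketch; framework imports omitted, and the actual file may differ):

def main():
    args = prepare_train_args()
    model = select_model(args)
    train_loader = select_train_loader(args)
    val_loader = select_eval_loader(args)
    trainer = Trainer(args, model, train_loader, val_loader)
    trainer.train()


if __name__ == '__main__':
    main()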

10. Prediction

Adapt src/main/engine/predictor.py to your dataset, then run the do_predict.sh script.

The predictor class:

class Predictor(object):
    def __init__(self, args, model, transform):
        self.__args = args
        self.__model = model
        self.__transform = transform
        self.__model.load_state_dict(torch.load(args.weights_path), strict=True)

        self.__device = torch.device(f"cuda:{args.gpus[0]}")
        if len(args.gpus) == 1:
            gpu_status = 'Single-GPU'
            self.__model.to(self.__device)
        else:
            gpu_status = 'Multi-GPU'
            # DataParallel expects the model to start on the first device in device_ids
            self.__model.to(self.__device)
            self.__model = torch.nn.DataParallel(self.__model, device_ids=args.gpus, output_device=args.gpus[0])

        print('****************************************************************************************************')
        print('Model:')
        print(self.__model)
        print('****************************************************************************************************')
        print('GPU   Status: ' + gpu_status)
        print('****************************************************************************************************')
        self.__model.eval()

    def predict_csv(self):
        df = pd.read_csv(self.__args.submission_file_path)
        for index, row in df.iterrows():
            test_file_path = os.path.join(self.__args.dataset_dir, 'train_valid_test', 'test', 'unknown', row.iloc[0])
            img = PIL.Image.open(test_file_path)
            input_test = self.__transform(img).unsqueeze(0).to(self.__device)
            with torch.no_grad():
                output_test = self.__model(input_test)
                output_test = torch.nn.functional.softmax(output_test, dim=1)
            label_test = output_test.argmax(dim=1).item()
            df.iloc[index, 1] = labels_map[label_test]
        print(df)
        df.to_csv(self.__args.submission_file_path, index=False)
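
The labels_map used in predict_csv() is not shown above; it maps a predicted class index back to the class-name string. Since ImageFolder assigns indices to class folders alphabetically, it can be derived by inverting class_to_idx (a hedged sketch):

# Invert the class_to_idx mapping produced by ImageFolder on the training set
train_dataset = torchvision.datasets.ImageFolder(
    os.path.join(args.dataset_dir, 'train_valid_test', 'train'))
labels_map = {index: class_name for class_name, index in train_dataset.class_to_idx.items()}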

11. Remarks

Many details are not explained as fully as I would like; this post will be revised and improved over time. Thanks to ahangchen/torch_base (https://github.com/ahangchen/torch_base, "Quickly bring up your PyTorch project (a skeleton)") for providing the base code.
