# -*- coding: utf-8 -*-
"""" This document is a simple Demo for DDP Image Classification """
from typing import Callable
from argparse import ArgumentParser, Namespace
import torch
from torch.backends import cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
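
# Note: a typical way to launch this script under DDP (assuming the file is
# saved as ddp_demo.py; set --num_processes to the number of GPUs):
#
#   accelerate launch --num_processes=2 ddp_demo.py --epochs 100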


def parse_args() -> Namespace:
    """Parse command-line arguments."""
parser = ArgumentParser()
    # Dataset path
parser.add_argument(
"-d",
"--dataset",
action="store",
default="/dev/shm/dataset",
type=str,
help="Dataset folder.",
)
    # Number of training epochs
parser.add_argument(
"-e",
"--epochs",
action="store",
default=248,
type=int,
help="Number of epochs to train.",
)
    # Mini-batch size
parser.add_argument(
"-bs",
"--batch-size",
action="store",
default=128,
type=int,
help="Size of mini batch.",
)
    # Optimizer choice
parser.add_argument(
"-opt",
"--optimizer",
action="store",
default="SGD",
type=str,
choices=["Adam", "SGD"],
help="Optimizer used to train the model.",
)
    # Initial learning rate
parser.add_argument(
"-lr",
"--learning-rate",
action="store",
default=2e-3,
type=float,
help="Learning rate.",
)
    # Random seed
parser.add_argument(
"-s",
"--seed",
action="store",
default=0,
type=int,
help="Random Seed.",
)
return parser.parse_args()


def prepare_model(num_classes: int = 1000) -> torch.nn.Module:
    """Build a ResNet-18 and replace its FC layer."""
with accelerator.local_main_process_first():
model: resnet.ResNet = resnet.resnet18(
weights=resnet.ResNet18_Weights.DEFAULT
)
        # For CIFAR, replace ResNet-18's first-layer 7x7 convolution with a
        # 3x3 one (the parameter count stays roughly the same).
model.conv1 = torch.nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
if num_classes != 1000:
model.fc = torch.nn.Linear(512, num_classes)
    total_params = sum(param.nelement() for param in model.parameters())
    accelerator.print(f"#params: {total_params / 1e6:.2f}M")
return model


def prepare_dataset(folder: str):
    """Load the CIFAR-10 dataset."""
normalize_transform = transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
)
with accelerator.local_main_process_first():
train_data = CIFAR10(
folder,
train=True,
transform=transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(0.25),
transforms.AutoAugment(
transforms.AutoAugmentPolicy.CIFAR10
),
transforms.ToTensor(),
normalize_transform,
]
),
download=accelerator.is_local_main_process,
)
test_data = CIFAR10(
folder,
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), normalize_transform]
),
download=accelerator.is_local_main_process,
)
train_eval_data = CIFAR10(
folder,
train=True,
transform=transforms.Compose(
[transforms.ToTensor(), normalize_transform]
),
)
return train_data, train_eval_data, test_data


def get_data_loader(
batch_size: int,
train_data: Dataset,
train_eval_data: Dataset,
test_data: Dataset,
) -> tuple[DataLoader, DataLoader, DataLoader]:
"""获取DataLoader"""
train_loader: DataLoader = DataLoader(
train_data,
batch_size,
shuffle=True,
pin_memory=True,
num_workers=2 if accelerator.num_processes == 1 else 0,
)
train_eval_loader: DataLoader = DataLoader(
train_eval_data,
batch_size * 2,
shuffle=False,
pin_memory=True,
num_workers=2 if accelerator.num_processes == 1 else 0,
)
test_loader: DataLoader = DataLoader(
test_data,
batch_size * 2,
shuffle=False,
pin_memory=True,
num_workers=2 if accelerator.num_processes == 1 else 0,
)
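    # accelerator.prepare() wraps each DataLoader so that, under DDP, every
    # process automatically iterates over its own shard of the dataset.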
return accelerator.prepare(train_loader, train_eval_loader, test_loader)


@torch.enable_grad()
def train_epoch(
model: torch.nn.Module,
loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
) -> None:
"""训练一轮"""
model.train()
dataloader_with_bar = tqdm(
dataloader, disable=(not accelerator.is_local_main_process)
)
for source, targets in dataloader_with_bar:
optimizer.zero_grad()
output: torch.Tensor = model(source)
loss = loss_func(output, targets)
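        # accelerator.backward() replaces loss.backward() and applies loss
        # scaling when mixed precision is enabled.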
accelerator.backward(loss)
optimizer.step()


@torch.no_grad()
def eval_epoch(
model: torch.nn.Module,
loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
dataloader: DataLoader,
) -> tuple[float, float]:
"""在指定测试集上测试模型的损失和准确率"""
model.eval()
dataloader_with_bar = tqdm(
dataloader, disable=(not accelerator.is_local_main_process)
)
correct_sum, loss_sum, cnt_samples = 0, 0.0, 0
for source, targets in dataloader_with_bar:
output: torch.Tensor = model(source)
loss = loss_func(output, targets)
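        # gather_for_metrics() gathers the per-sample hits from all processes
        # and drops the duplicates padded into the last batch, keeping the
        # accuracy exact under DDP.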
prediction: torch.Tensor = accelerator.gather_for_metrics(
output.argmax(dim=1) == targets
) # type: ignore
correct_sum += prediction.sum().item()
loss_sum += loss.item()
cnt_samples += len(prediction)
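    # Note: the returned loss is a per-process average over batches; only the
    # accuracy is computed from predictions gathered across all processes.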
return loss_sum / len(dataloader), correct_sum / cnt_samples


def main(args: Namespace):
    """Main training function."""
set_seed(args.seed)
model = prepare_model(10)
train_data, train_eval_data, test_data = prepare_dataset(args.dataset)
train_loader, train_eval_loader, test_loader = get_data_loader(
args.batch_size, train_data, train_eval_data, test_data
)
    optimizer: torch.optim.Optimizer = (
        torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=0.9,
            weight_decay=2e-2,
        )
        if args.optimizer == "SGD"
        else torch.optim.Adam(model.parameters(), args.learning_rate)
    )
loss_func = torch.nn.CrossEntropyLoss(label_smoothing=0.05)
scheduler: CosineAnnealingWarmRestarts = CosineAnnealingWarmRestarts(
optimizer, 8, 2
)
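    # Warm restarts: the first cosine cycle lasts 8 epochs, and every
    # subsequent cycle is twice as long (T_0=8, T_mult=2).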
model, optimizer, loss_func, scheduler = accelerator.prepare(
model, optimizer, loss_func, scheduler
)
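    # prepare() moves the modules to the right device and, when more than one
    # process is launched, wraps the model in DistributedDataParallel.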
best_acc = 0
    log_file = open("log.csv", "wt")  # opened on every process, but only the local main process writes to it
if accelerator.is_local_main_process:
print(
"epoch,train_loss,train_acc,val_loss,val_acc,learning_rate",
file=log_file,
)
log_file.flush()
for epoch in range(args.epochs + 1):
accelerator.print(
f"Epoch {epoch}/{args.epochs}",
f"(lr={optimizer.param_groups[-1]['lr']}):",
)
        # Train the model (epoch 0 skips training, so the first pass just
        # evaluates the pretrained baseline)
if epoch != 0:
train_epoch(model, loss_func, train_loader, optimizer)
accelerator.wait_for_everyone()
        # Evaluate the model on the training and test sets
train_loss, train_acc = eval_epoch(model, loss_func, train_eval_loader)
val_loss, val_acc = eval_epoch(model, loss_func, test_loader)
accelerator.print(
f"[ Training ] Acc: {train_acc * 100:.2f}% Loss: {train_loss:.4f}"
)
        # Save the best weights
accelerator.wait_for_everyone()
if accelerator.is_local_main_process:
print(
epoch,
train_loss,
train_acc,
val_loss,
val_acc,
optimizer.param_groups[-1]["lr"],
sep=",",
file=log_file,
)
log_file.flush()
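        # save_model() unwraps the model and writes its weights (by default as
        # model.safetensors) into the given directory.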
accelerator.save_model(model, "./weights/last")
if val_acc > best_acc:
best_acc = val_acc
accelerator.save_model(model, "./weights/best")
accelerator.wait_for_everyone()
accelerator.print(
f"[Validation] Acc: {val_acc * 100:.2f}%",
f"Loss: {val_loss:.4f}",
f"Best: {best_acc * 100:.2f}%",
)
if epoch != 0:
scheduler.step()
log_file.close()


if __name__ == "__main__":
cudnn.benchmark = True
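    # The Accelerator is created at module scope so the helper functions above
    # can reference it.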
accelerator = Accelerator()
main(parse_args())
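
# A minimal sketch of how the saved weights could be reloaded later (assumed
# usage; load_checkpoint_in_model is provided by accelerate):
#
#   from accelerate import load_checkpoint_in_model
#   model = prepare_model(10)
#   load_checkpoint_in_model(model, "./weights/best")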