首页 > 编程语言 >【Python实现连续学习算法】复现2018年ECCV经典算法RWalk

【Python实现连续学习算法】复现2018年ECCV经典算法RWalk

时间:2025-01-05 15:35:04浏览次数:3  
标签:task ECCV self torch RWalk 算法 train id size

Python实现连续学习Baseline 及经典算法RWalk

在这里插入图片描述

1 连续学习概念及灾难性遗忘

连续学习(Continual Learning)是一种模拟人类学习过程的机器学习方法,它旨在让模型在面对多个任务时能够连续学习,而不会遗忘已学到的知识。然而,大多数深度学习模型在连续学习多个任务时会出现“灾难性遗忘”(Catastrophic Forgetting)现象。灾难性遗忘指模型在学习新任务时会大幅度遗忘之前学到的任务知识,这是因为模型参数在新任务的训练过程中被完全覆盖。

解决灾难性遗忘问题是连续学习研究的核心。目前已有多种方法被提出,包括正则化方法、回放、架构等等的方法,其中EWC(Elastic Weight Consolidation)是一种经典的正则化方法。

2 PermutdMNIST数据集及模型

PermutedMNIST是连续学习领域的一种经典测试数据集。它通过对MNIST数据集中的像素进行随机置换生成不同的任务。每个任务都是一个由置换规则决定的分类问题,但所有任务共享相同的标签空间。

对于模型的选择,通常采用简单的全连接神经网络。网络结构可以包含若干个隐藏层,每个隐藏层具有一定数量的神经元,并使用ReLU作为激活函数。网络的输出层与标签类别数一致。

模型在训练每个任务时需要调整参数,研究灾难性遗忘问题的严重程度,并在引入算法时测试其对连续学习能力的改善效果。

import random
import torch
from torchvision import datasets
import os
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
# Set seeds
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class PermutedMNIST(datasets.MNIST):
    def __init__(self, root="./data/mnist", train=True, permute_idx=None):
        super(PermutedMNIST, self).__init__(root, train, download=True)
        assert len(permute_idx) == 28 * 28
        if self.train:
            self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                      for img in self.data])
        else:
            self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                      for img in self.data])

    def __getitem__(self, index):
        if self.train:
            img, target = self.data[index], self.train_labels[index]
        else:
            img, target = self.data[index], self.test_labels[index]
        return img.view(1, 28, 28), target

    def get_sample(self, sample_size):
        random.seed(2024)
        sample_idx = random.sample(range(len(self)), sample_size)
        return [img.view(1, 28, 28) for img in self.data[sample_idx]]
def worker_init_fn(worker_id):
    # 确保每个 worker 的随机种子一致
    random.seed(2024 + worker_id)
    np.random.seed(2024 + worker_id)
def get_permute_mnist(num_task, batch_size):
    random.seed(2024)
    train_loader = {}
    test_loader = {}
    root_dir = './data/permuted_mnist'
    os.makedirs(root_dir, exist_ok=True)

    for i in range(num_task):
        permute_idx = list(range(28 * 28))
        random.shuffle(permute_idx)

        train_dataset_path = os.path.join(root_dir, f'train_dataset_{i}.pt')
        test_dataset_path = os.path.join(root_dir, f'test_dataset_{i}.pt')

        if os.path.exists(train_dataset_path) and os.path.exists(test_dataset_path):

            train_dataset = torch.load(train_dataset_path)
            test_dataset = torch.load(test_dataset_path)
        else:
            train_dataset = PermutedMNIST(train=True, permute_idx=permute_idx)
            test_dataset = PermutedMNIST(train=False, permute_idx=permute_idx)
            torch.save(train_dataset, train_dataset_path)
            torch.save(test_dataset, test_dataset_path)

        train_loader[i] = DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                    #  num_workers=1,
                                     worker_init_fn=worker_init_fn,
                                     pin_memory=True)
        test_loader[i] = DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    #  num_workers=1,
                                     worker_init_fn=worker_init_fn,
                                     pin_memory=True)

    return train_loader, test_loader

class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, num_classes_per_task=10, hidden_size=[400, 400, 400]):
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        
        # 初始化类别计数器
        self.total_classes = num_classes_per_task
        self.num_classes_per_task = num_classes_per_task
        
        # 定义网络结构
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.fc_before_last = nn.Linear(hidden_size[1], hidden_size[2])
        
        self.fc_out = nn.Linear(hidden_size[2], self.total_classes)
    
    def forward(self, input, task_id=-1):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc_before_last(x))
        x = self.fc_out(x)
        return x

3 Baseline代码

没有任何连续学习算法的Baseline代码实现仅仅是将任务逐个训练。具体过程为:依次加载每个任务的数据集,独立训练模型,而不考虑模型对前一个任务的记忆能力。


class Baseline:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2, neurons=0):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28

        # Initialize model
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()


        # Get dataset
        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)
    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                # Move data to GPU in batches
                images = images.view(-1,self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        return 100.0 * correct / total


    def train_task(self, train_loader,optimizer, task_id):
        self.model.train()
        for images, labels in train_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def run(self):
        all_avg_acc = []
        
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader,optimizer, task_id)
            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)

            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
            all_avg_acc.append(mean_avg)
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc},AVG = {avg_acc}")

if __name__ == '__main__':
    print('Baseline'+"=" * 50)
    random.seed(2024)
    torch.manual_seed(2024)
    np.random.seed(2024)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    baseline = Baseline(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
    baseline.run()

Baseline==================================================

Task 0: Task Acc = [96.78],AVG=96.78

Task 1: Task Acc = [85.19, 97.0],AVG=91.1

Task 2: Task Acc = [52.66, 89.14, 97.27],AVG=79.69

Task AVG Acc: [96.78, 91.1, 79.69],AVG = 89.19

可以看到模型在学习新任务后,旧任务的准确率在下降,在学习完Task2后,第一个任务的准确率只有52.66,第二个任务的准确率只有89.14。

4 MAS算法

4.1 算法原理

RWalk算法是一种增量学习框架,它通过结合Fisher信息矩阵和优化路径上参数重要性的累积来平衡对旧任务的记忆保持(避免灾难性遗忘)和新任务的学习能力(减少固执性)。

论文《Chaudhry A, Dokania P K, Ajanthan T, et al. Riemannian walk for incremental learning: Understanding forgetting and intransigence[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 532-547.》Riemannian Walk for Incremental Learning (RWalk) 算法中,计算重要性权重和损失函数的公式如下:

  1. 重要性权重的计算:

    • Fisher 信息矩阵的更新:
      F t θ = α F t θ + ( 1 − α ) F t − 1 θ F_t^\theta = \alpha F_t^\theta + (1 - \alpha) F_{t-1}^\theta Ftθ​=αFtθ​+(1−α)Ft−1θ​
      其中, F t θ F_t^\theta Ftθ​ 是在第 t t t 次迭代时的 Fisher 信息矩阵, α \alpha α 是一个超参数。

    • 参数重要性得分的累积:
      s t 2 t 1 ( θ i ) = ∑ t = t 1 t 2 Δ L t t + Δ t ( θ i ) 1 2 F t θ i Δ θ i ( t ) 2 + ϵ s_{t_2}^{t_1}(\theta_i) = \sum_{t=t_1}^{t_2} \frac{\Delta L_t^{t+\Delta t}(\theta_i)}{\frac{1}{2} F_t^{\theta_i} \Delta \theta_i(t)^2 + \epsilon} st2​t1​​(θi​)=t=t1​∑t2​​21​Ftθi​​Δθi​(t)2+ϵΔLtt+Δt​(θi​)​

      其中, Δ L t t + Δ t ( θ i ) \Delta L_t^{t+\Delta t}(\theta_i) ΔLtt+Δt​(θi​) 是参数 θ i \theta_i θi​ 从时间步 t t t 到 t + Δ t t + \Delta t t+Δt 的损失变化, F t θ i F_t^{\theta_i} Ftθi​​ 是第 t t t 次迭代时 θ i \theta_i θi​ 的 Fisher 信息, Δ θ i ( t ) = θ i ( t + Δ t ) − θ i ( t ) \Delta \theta_i(t) = \theta_i(t + \Delta t) - \theta_i(t) Δθi​(t)=θi​(t+Δt)−θi​(t), ϵ \epsilon ϵ 是一个正的常数。

  2. 损失函数的计算:

    • 最终目标函数 (RWalk):
      L ~ k ( θ ) = L k ( θ ) + λ ∑ i = 1 P ( F k − 1 θ i + s t 0 t k − 1 ( θ i ) ) ( θ i − θ k − 1 i ) 2 \tilde{L}_k(\theta) = L_k(\theta) + \lambda \sum_{i=1}^P \left( F_{k-1}^{\theta_i} + s_{t_0}^{t_{k-1}}(\theta_i) \right) (\theta_i - \theta_{k-1}^i)^2 L~k​(θ)=Lk​(θ)+λi=1∑P​(Fk−1θi​​+st0​tk−1​​(θi​))(θi​−θk−1i​)2

    其中, L k ( θ ) L_k(\theta) Lk​(θ) 是第 k k k 个任务的损失函数, λ \lambda λ 是一个超参数, F k − 1 θ i F_{k-1}^{\theta_i} Fk−1θi​​ 是第 k − 1 k-1 k−1 个任务结束时 θ i \theta_i θi​ 的 Fisher 信息, s t 0 t k − 1 ( θ i ) s_{t_0}^{t_{k-1}}(\theta_i) st0​tk−1​​(θi​) 是从第 t 0 t_0 t0​ 次迭代到第 t k − 1 t_{k-1} tk−1​ 次迭代 θ i \theta_i θi​ 的重要性得分, θ k − 1 i \theta_{k-1}^i θk−1i​ 是第 k − 1 k-1 k−1 个任务结束时 θ i \theta_i θi​ 的值。

4.2 代码实现


import torch
import torch.nn as nn
import random
import warnings
import numpy as np
import warnings
warnings.filterwarnings("ignore")

# Set seeds
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True  # Enable for GPU efficiency

class RWalk:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2, neurons=0):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28

        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.scaler = torch.cuda.amp.GradScaler()  # Enable mixed precision
        self.importance_dict = {}
        self.previous_params = {}
        self.path_integral = {}

        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)

        self.update_params()

    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images = images.view(-1,self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        return 100.0 * correct / total

    def train_task(self, train_loader,optimizer, task_id):
        self.model.train()

        for images, labels in train_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            if task_id > 0:
                loss = self.rwalk_multi_objective_loss(outputs, labels)
            else:
                loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def compute_importance(self, data_loader, task_id):
        # EWC++ 
        importance_dict = {name: torch.zeros_like(param, device=self.device) for name, param in self.model.named_parameters() if 'task' not in name}
        self.model.eval()

        for images, labels in data_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            self.model.zero_grad()

            outputs = self.model(images, task_id=task_id)
            loss = self.criterion(outputs, labels)
            loss.backward()
            for name, param in self.model.named_parameters():
                if name in importance_dict and param.requires_grad:
                    importance_dict[name] += param.grad ** 2 / len(data_loader)

        # 移动平均更新Fisher Matrix
        for name in importance_dict:
            if name in self.importance_dict:
                self.importance_dict[name] = 0.9 * self.importance_dict[name] + 0.1 * importance_dict[name]
            else:
                self.importance_dict[name] = importance_dict[name]

    def update_path_integral(self):
        # 计算累计重要性
        for name, param in self.model.named_parameters():
            if name in self.path_integral:
                self.path_integral[name] += (param.detach() - self.previous_params[name]) ** 2
            else:
                self.path_integral[name] = (param.detach() - self.previous_params[name]) ** 2

    def update_params(self):
        for name, param in self.model.named_parameters():
            self.previous_params[name] = param.clone().detach()

    def update(self, dataset, task_id):
        self.compute_importance(dataset, task_id)
        self.update_path_integral()
        self.update_params()

    def rwalk_multi_objective_loss(self, outputs, labels, lambda_=100):
        regularization_loss = 0.0
        for name, param in self.model.named_parameters():
            if name in self.importance_dict and name in self.previous_params and name in self.path_integral:
                fisher_importance = self.importance_dict[name]
                path_penalty = self.path_integral[name]
                previous_param = self.previous_params[name]
                regularization_loss += ((fisher_importance + path_penalty) * (param - previous_param).pow(2)).sum()
        loss = self.criterion(outputs, labels)
        total_loss = loss + lambda_ * regularization_loss
        return total_loss

    def run(self):
        all_avg_acc = []
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader,optimizer, task_id)
            self.update(train_loader, task_id)

            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)
            all_avg_acc.append(mean_avg)
            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc}, AVG = {avg_acc}")

if __name__ == '__main__':
    print('RWalk' + "=" * 50)
    for _ in range(1):
        random.seed(2024)
        torch.manual_seed(2024)
        np.random.seed(2024)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        rwalk = RWalk(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
        rwalk.run()

RWalk==================================================
Task 0: Task Acc = [96.78],AVG=96.78
Task 1: Task Acc = [94.91, 95.73],AVG=95.32
Task 2: Task Acc = [86.88, 89.66, 93.76],AVG=90.1
Task AVG Acc: [96.78, 95.32, 90.1], AVG = 94.06666666666666

在学习完每个任务后,旧任务的准确率只是轻微的下降,说明该算法有效的缓解了灾难性遗忘。

标签:task,ECCV,self,torch,RWalk,算法,train,id,size
From: https://blog.csdn.net/weixin_43935696/article/details/144942305

相关文章