一、手动法
二、利用lr_scheduler()提供的集中衰减函数
2.1 利用lr_lambda函数
具体使用:
from torch.optim import SGD, lr_scheduler
import matplotlib.pyplot as plt
from torch.nn import Module, Sequential, Linear, CrossEntropyLoss
# 定义网络模型
class model(Module):
def __init__(self):
super(model, self).__init__()
self.fc = Sequential(
Linear(1,10)
)
def forward(self, input):
output = self.fc(input)
return output
# 初始化网络模型
Model = model()
# 定义损失函数
Loss = CrossEntropyLoss()
# 创建优化器
lr = 0.01
optimizer = SGD(Model.parameters(), lr=lr)
# 定义一个list保存学习率
lr_list = []
# 定义学习率与轮数关系的函数
lambda1 = lambda epoch:0.95 ** epoch # 学习率 = 0.95**(轮数)
scheduler = lr_scheduler.LambdaLR(optimizer,lr_lambda = lambda1)
for epoch in range(100):
print("epoch={}, lr={}".format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
scheduler.step()
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r',label = 'LambdaLR')
plt.ylabel('learning rate')
plt.xlabel('epoch')
plt.legend()
plt.show()