Generally speaking, when the batch size is larger, the learning rate should also be larger. There is a mathematical relationship between the two; when we change the batch size, the learning rate may follow this rule:
\text{new\_learning\_rate} = \text{old\_learning\_rate} \times \sqrt{\frac{\text{new\_batch\_size}}{\text{old\_batch\_size}}}
For example, if someone originally trained with a batch size of 128 and a learning rate of 0.0005, then when we change the batch size to 1024, the recommended new learning rate is: 0.0005 * sqrt(1024/128) = 0.0005 * sqrt(8) ≈ 0.001414
import math

def calculate_new_learning_rate(old_learning_rate, new_batch_size, old_batch_size):
    # Compute the new learning rate using the square-root scaling rule
    new_learning_rate = old_learning_rate * math.sqrt(new_batch_size / old_batch_size)
    # Convert the new learning rate to scientific notation, rounding the coefficient to an integer
    exponent = int(math.floor(math.log10(new_learning_rate)))
    coefficient = round(new_learning_rate / 10**exponent)
    new_learning_rate_str = f"{coefficient}e{exponent}"
    return new_learning_rate_str

# Example
old_learning_rate = 5e-4
old_batch_size = 128
new_batch_size = 1024

# Call the function to compute the new learning rate
new_learning_rate_str = calculate_new_learning_rate(old_learning_rate, new_batch_size, old_batch_size)
print(f"New learning rate: {new_learning_rate_str}")
# New learning rate: 1e-3
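Note that the square-root rule above is not the only convention: the linear scaling rule (scale the learning rate proportionally with the batch-size ratio) is also commonly used in practice. A minimal sketch comparing the two, reusing the example numbers from above (function names here are illustrative, not from any library):

```python
import math

def sqrt_scaled_lr(old_lr, new_bs, old_bs):
    # Square-root scaling: lr grows with the square root of the batch-size ratio
    return old_lr * math.sqrt(new_bs / old_bs)

def linear_scaled_lr(old_lr, new_bs, old_bs):
    # Linear scaling: lr grows proportionally with the batch-size ratio
    return old_lr * (new_bs / old_bs)

old_lr, old_bs, new_bs = 5e-4, 128, 1024
print(f"sqrt scaling:   {sqrt_scaled_lr(old_lr, new_bs, old_bs):.6f}")    # ≈ 0.001414
print(f"linear scaling: {linear_scaled_lr(old_lr, new_bs, old_bs):.6f}")  # = 0.004000
```

The linear rule produces a noticeably larger learning rate at large batch sizes, so which rule works better depends on the optimizer and training setup; treat either value as a starting point for tuning rather than a fixed answer.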