
A custom optimizer in TF2


# -*- coding: utf-8 -*-

from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


class Adammom(optimizer_v2.OptimizerV2):
    """Adammom Optimizer

    w: trainable weights
    d2sum = 0.0
    ada_decay_rate = 0.9999
    ada_epsilon = 1e-8
    learning_rate = 0.0001
    mom_decay_rate = 0.99

    d2sum = d2sum * ada_decay_rate + 1
    for i in range(len(w)):
        g2sum[i] = g2sum[i] * ada_decay_rate + grad[i] * grad[i]
        scale = sqrt((1.0 + ada_epsilon)/(g2sum[i]/d2sum + ada_epsilon))
        velocity[i] = mom_decay_rate * velocity[i] + (1 - mom_decay_rate) * grad[i]
        w[i] = w[i] - learning_rate * velocity[i] * scale

    Args:
    ada_decay_rate: (float) The decay rate that controls how fast g2sum decays. Defaults to 0.9999.
    ada_epsilon: (float) A very small value used to stabilize the scale. Defaults to 1e-08.
    learning_rate: (float) The learning rate of Adammom. Defaults to 0.0001.
    mom_decay_rate: (float) The decay rate of the momentum. Defaults to 0.99.
    """

    _HAS_AGGREGATE_GRAD = True

    def __init__(
        self,
        learning_rate=0.0001,
        ada_decay_rate=0.9999,
        ada_epsilon=1e-08,
        mom_decay_rate=0.99,
        name="Adammom",
        **kwargs
    ):
        super(Adammom, self).__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("ada_decay_rate", ada_decay_rate)
        self._set_hyper("mom_decay_rate", mom_decay_rate)
        self.ada_epsilon = ada_epsilon

    def _create_slots(self, var_list):
        # Create slots for the first and second moments.
        # Separate for-loops to respect the ordering of slot variables from v1.
        for var in var_list:
            self.add_slot(var, "d2sum")
        for var in var_list:
            self.add_slot(var, "g2sum")
        for var in var_list:
            self.add_slot(var, "velocity")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(Adammom, self)._prepare_local(var_device, var_dtype, apply_state)

        ada_decay_rate_t = array_ops.identity(
            self._get_hyper("ada_decay_rate", var_dtype)
        )
        mom_decay_rate_t = array_ops.identity(
            self._get_hyper("mom_decay_rate", var_dtype)
        )

        apply_state[(var_device, var_dtype)].update(
            dict(
                ada_epsilon=ops.convert_to_tensor_v2_with_dispatch(
                    self.ada_epsilon, var_dtype
                ),
                ada_decay_rate_t=ada_decay_rate_t,
                mom_decay_rate_t=mom_decay_rate_t,
            )
        )

    @def_function.function(jit_compile=True)
    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)
        # TODO(lebronzheng): The following calculations should be fused into a C++ kernel
        d2sum = self.get_slot(var, "d2sum")
        g2sum = self.get_slot(var, "g2sum")
        ada_decay_rate = coefficients["ada_decay_rate_t"]
        # d2sum = d2sum * ada_decay_rate + 1
        d2sum.assign(d2sum * ada_decay_rate + 1)
        # g2sum = g2sum[i] * ada_decay_rate + grad[i] * grad[i]
        g2sum.assign(g2sum * ada_decay_rate + math_ops.square(grad))
        # scale = sqrt((1.0 + ada_epsilon)/(g2sum/d2sum + ada_epsilon))
        ada_epsilon = coefficients["ada_epsilon"]
        scale = math_ops.sqrt((1 + ada_epsilon) / (g2sum / d2sum + ada_epsilon))
        # velocity = mom_decay_rate * velocity + (1 - mom_decay_rate) * grad
        mom_decay_rate = coefficients["mom_decay_rate_t"]
        velocity = self.get_slot(var, "velocity")
        velocity.assign(mom_decay_rate * velocity + (1 - mom_decay_rate) * grad)
        # w = w - learning_rate * velocity * scale
        var.assign_sub(coefficients["lr_t"] * velocity * scale)

    @def_function.function(jit_compile=True)
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        raise NotImplementedError("Not implemented currently")

    def get_config(self):
        config = super(Adammom, self).get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "decay": self._initial_decay,
                "ada_decay_rate": self._serialize_hyperparameter("ada_decay_rate"),
                "mom_decay_rate": self._serialize_hyperparameter("mom_decay_rate"),
                "ada_epsilon": self.ada_epsilon,
            }
        )
        return config
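
Once the class is defined, it plugs into Keras like any built-in optimizer. The snippet below is a minimal usage sketch, not part of the original post; the toy model, the data shapes and the hyperparameter value are placeholders, and it assumes a TF 2.x version in which tf.keras is still backed by optimizer_v2.

import tensorflow as tf

# Assumed usage sketch: compile a toy model with the custom optimizer and run one epoch.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer=Adammom(learning_rate=0.0001), loss="mse")

x = tf.random.normal([32, 8])
y = tf.random.normal([32, 1])
model.fit(x, y, epochs=1, verbose=0)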

  

A few notes on the implementation:

1. _resource_apply_sparse is designed mainly for sparse scenarios, e.g. implementing LazyAdam: only the rows selected by the gradient's indices are updated and every other row is left untouched (a sketch of such a sparse apply follows below).

2. self._iterations is the number of update steps the optimizer has applied. It is useful in optimizers that depend on the time step, e.g. the t in the β^t term of Adam. But this counter is optimizer-level, i.e. all variables managed by the optimizer share one iterations value. If every step updates all parameters there is no problem; if each step only updates some of them, the t of the untouched parameters is effectively incremented as well, so the computed result no longer matches the original Adam formula. Whether this actually hurts quality is not guaranteed and needs to be verified experimentally. To get a parameter-level iteration count, just create the iteration variable as a slot in _create_slots and add one to it on every apply (a sketch of such a per-variable counter also follows below).

3. _create_slots is where optimizer state beyond the trainable parameters is defined, e.g. momentum, energy, and similar accumulators.
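
For point 1, here is one possible LazyAdam-style _resource_apply_sparse for the Adammom class above. This is an assumed sketch rather than code from the original post: it reuses the existing slots and coefficients, touches only the rows selected by indices, and writes the results back through tf.IndexedSlices with the resource-variable scatter_update / scatter_sub methods (so it additionally assumes import tensorflow as tf).

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        # Sketch only: update just the rows picked out by `indices`; every other
        # row of the variable and of the slots stays untouched (LazyAdam behaviour).
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        d2sum = self.get_slot(var, "d2sum")
        g2sum = self.get_slot(var, "g2sum")
        velocity = self.get_slot(var, "velocity")
        ada_decay_rate = coefficients["ada_decay_rate_t"]
        ada_epsilon = coefficients["ada_epsilon"]
        mom_decay_rate = coefficients["mom_decay_rate_t"]

        # Gather the selected rows, apply the same update rule as the dense path,
        # then scatter the results back.
        d2sum_rows = array_ops.gather(d2sum, indices) * ada_decay_rate + 1
        d2sum.scatter_update(tf.IndexedSlices(d2sum_rows, indices))
        g2sum_rows = array_ops.gather(g2sum, indices) * ada_decay_rate + math_ops.square(grad)
        g2sum.scatter_update(tf.IndexedSlices(g2sum_rows, indices))
        scale = math_ops.sqrt((1 + ada_epsilon) / (g2sum_rows / d2sum_rows + ada_epsilon))
        velocity_rows = mom_decay_rate * array_ops.gather(velocity, indices) + (1 - mom_decay_rate) * grad
        velocity.scatter_update(tf.IndexedSlices(velocity_rows, indices))
        var.scatter_sub(tf.IndexedSlices(coefficients["lr_t"] * velocity_rows * scale, indices))

One caveat of this scatter-based version: if the same row index appears several times in one batch of gradients, the duplicated writes are not accumulated, which is the usual trade-off LazyAdam-style updates make for speed.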
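
For point 2, a per-variable step counter can be kept as just another slot. The fragment below is an assumed sketch against the Adammom class above (the slot name "step" and the bias-correction comment are illustrative, not from the original post): the counter is created in _create_slots and advanced only when the variable actually receives an update.

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, "step")  # per-variable update count, zero-initialized
        # ... create the d2sum / g2sum / velocity slots exactly as before

    def _resource_apply_dense(self, grad, var, apply_state=None):
        step = self.get_slot(var, "step")
        step.assign_add(array_ops.ones_like(step))  # advance t only for this variable
        # `step` can now replace self._iterations in any time-step-dependent term,
        # for example an Adam-style bias correction factor 1 - mom_decay_rate ** step.
        # ... rest of the dense update as before

Keeping the slot the same shape as the variable is slightly wasteful compared with a scalar counter, but it means a sparse apply can advance the count only for the rows it actually touches, which is exactly the per-parameter behaviour point 2 asks for.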

From: https://www.cnblogs.com/deepllz/p/16707225.html
