首页 > 编程语言 >tensorflow_probability.python.bijectors的一些使用

tensorflow_probability.python.bijectors的一些使用

时间:2022-12-28 17:02:12浏览次数:62  
标签:__ bijectors log probability python self forward tf def

 

 

网上见到一个TensorFlow的代码,没见过这个形式的,是概率编程的代码:

# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tanh bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tensorflow_probability.python.bijectors import bijector


__all__ = [
    "Tanh",
]


class Tanh(bijector.Bijector):
  """Bijector that computes `Y = tanh(X)`, therefore `Y in (-1, 1)`.

  This can be achieved by an affine transform of the Sigmoid bijector, i.e.,
  it is equivalent to
  ```
  tfb.Chain([tfb.Affine(shift=-1, scale=2.),
             tfb.Sigmoid(),
             tfb.Affine(scale=2.)])
  ```

  However, using the `Tanh` bijector directly is slightly faster and more
  numerically stable.
  """

  def __init__(self, validate_args=False, name="tanh"):
    parameters = dict(locals())
    super(Tanh, self).__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        parameters=parameters,
        name=name)

  def _forward(self, x):
    return tf.nn.tanh(x)

  def _inverse(self, y):
    # 0.99999997 is the maximum value such that atanh(x) is valid for both
    # tf.float32 and tf.float64
    y = tf.where(tf.less_equal(tf.abs(y), 1.),
                 tf.clip_by_value(y, -0.99999997, 0.99999997),
                 y)
    return tf.atanh(y)

  def _forward_log_det_jacobian(self, x):
    #  This formula is mathematically equivalent to
    #  `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more numerically
    #  stable.

    #  Derivation:
    #    log(1 - tanh(x)^2)
    #    = log(sech(x)^2)
    #    = 2 * log(sech(x))
    #    = 2 * log(2e^-x / (e^-2x + 1))
    #    = 2 * (log(2) - x - log(e^-2x + 1))
    #    = 2 * (log(2) - x - softplus(-2x))
    return 2.0 * (
        tf.math.log(tf.constant(2.0, dtype=x.dtype)) - x - tf.nn.softplus(
            -2.0 * x))

 

 

 

================================================

 

 

由于不是很理解这个代码的意思,于是找了下TensorFlow的官方文档:

https://tensorflow.google.cn/probability/api_docs/python/tfp/bijectors/Bijector

 

 

  class Exp(Bijector):

    def __init__(self, validate_args=False, name='exp'):
      super(Exp, self).__init__(
          validate_args=validate_args,
          forward_min_event_ndims=0,
          name=name)

    def _forward(self, x):
      return tf.exp(x)

    def _inverse(self, y):
      return tf.log(y)

    def _inverse_log_det_jacobian(self, y):
      return -self._forward_log_det_jacobian(self._inverse(y))

    def _forward_log_det_jacobian(self, x):
      # Notice that we needn't do any reducing, even when`event_ndims > 0`.
      # The base Bijector class will handle reducing for us; it knows how
      # to do so because we called `super` `__init__` with
      # `forward_min_event_ndims = 0`.
      return x
  ```

 

标签:__,bijectors,log,probability,python,self,forward,tf,def
From: https://www.cnblogs.com/devilmaycry812839668/p/17010517.html

相关文章

  • python中global 和 nonlocal 的作用域
    python引用变量的顺序: 当前作用域局部变量->外层作用域变量->当前模块中的全局变量->python内置变量。一globalglobal关键字用来在函数或其他局部作用域中使用全局变量。......
  • 【leetcode】3: 无重复字串的最长子串(python)
    给定一个字符串s,请你找出其中不含有重复字符的 最长子串 的长度。 示例 1:输入:s="abcabcbb"输出:3解释:因为无重复字符的最长子串是"abc",所以其长度为3......
  • [oeasy]python0033_回车_carriage_return_figlet_字体变大
    回到开头回忆上次内容进程前后台切换<kbd>ctrl</kbd>+<kbd>z</kbd>把当前进程切换到后台并暂停​​jobs​​查看所有作业用​​fg​​可以把后台进程再切回前台​​......
  • python logging配置
    python中,logging由logger,handler,filter,formater四个部分组成。logger是提供我们记录日志的方法;handler是让我们选择日志的输出地方,如:控制台,文件,邮件发送等,一个logger添加......
  • python中的mysql操作教程及实例
    一.数据库在自动化测试中的应用存测试数据有的时候大批量的数据,我们需要存到数据库中,在测试的时候才能用到,测试的时候就从数据库中读取出来。这点是非常重要的!存测试结......
  • 数值计算:前向和反向自动微分(Python实现)
    1自动微分我们在《数值分析》课程中已经学过许多经典的数值微分方法。许多经典的数值微分算法非常快,因为它们只需要计算差商。然而,他们的主要缺点在于他们是数值的,这意味......
  • python中的集合推导式
    集合推导式可用来去重需求:将列表list1=[2,2,2,3,4,4,4]中的偶数进行筛选,并且去重list1=[2,2,2,3,4,4,4]set1={iforiinlist1ifi%2==0}print(set......
  • python中的列表推导式
    1.单列表,单条件求1-20之间的偶数list1=[]foriinrange(1,21):ifi%2==0:list1.append(i)print(list1)列表推导式list2=[iforiinrange(1,21)if......
  • Python爬虫(一)热身
    基础操作一importurllibimportchardet#字符集检测url="http://www.163.com/"html=urllib.urlopen(url)printhtml.headers#头部信息printhtml.getcode()#状态......
  • 当我把用Python做的课堂点名系统献给各科老师后,再也没挂过科
    刚上大学的表弟问我,大学准备好好玩玩,问我有没有什么不挂科的秘诀。哎,这可就问对人了,要想不挂科,先把老师贿赂好,当然,咱们说的贿赂不是送钱啥的,这不是侮辱老师吗?于是我......