首页 > 编程语言 >机器学习算法原理实现——EM算法

机器学习算法原理实现——EM算法

时间:2023-11-16 12:01:32浏览次数:22  
标签:EM 机器 算法 tails heads theta likelihood

【EM算法简介】

EM算法,全称为期望最大化算法(Expectation-Maximization Algorithm),是一种迭代优化算法,主要用于含有隐变量的概率模型参数的估计。EM算法的基本思想是:如果给定模型的参数,那么可以根据模型计算出隐变量的期望值;反过来,如果给定隐变量的值,那么可以通过最大化似然函数来估计模型的参数。EM算法就是通过交替进行这两步来找到参数的最大似然估计。

EM算法的基本步骤如下:

1. 初始化模型参数
2. E步:计算隐变量的期望值
3. M步:最大化似然函数,更新模型参数
4. 重复步骤2和3,直到模型参数收敛

【EM算法举例】

K-means算法可以被看作是一种特殊的EM算法。在K-means算法中,我们试图找到一种方式将数据点分配到K个集群中,使得每个数据点到其所在集群中心的距离之和最小。

如果我们将集群分配看作是隐变量,那么K-means算法就可以看作是EM算法:

1. E步:期望步骤。给定当前的集群中心(模型参数),我们可以计算每个数据点最近的集群中心,也就是将每个数据点分配到一个集群中。这个步骤就是计算隐变量的期望值。

2. M步:最大化步骤。给定当前的集群分配(隐变量的值),我们可以计算新的集群中心,也就是每个集群中所有数据点的均值。这个步骤就是最大化似然函数,更新模型参数。

通过交替进行E步和M步,K-means算法可以找到一种集群分配和集群中心,使得每个数据点到其所在集群中心的距离之和最小。这就是K-means算法使用EM算法的地方。

 

【再举一个例子】

见:https://zhuanlan.zhihu.com/p/78311644 写得非常好,关键摘录如下:

机器学习算法原理实现——EM算法_最大似然估计

机器学习算法原理实现——EM算法_初始化_02

 

机器学习算法原理实现——EM算法_最大似然估计_03

 

 

【python编程实现】

import math
import random

def coin_em(rolls, theta_A=None, theta_B=None, maxiter=10000, tol=1e-6):
    # 初始化参数
    theta_A = theta_A or random.random()
    theta_B = theta_B or random.random()
    loglike_old = 0
    for i in range(maxiter):
        # E步
        heads_A, tails_A, heads_B, tails_B = e_step(rolls, theta_A, theta_B)
        # M步
        theta_A, theta_B = m_step(heads_A, tails_A, heads_B, tails_B)
        # 计算对数似然
        loglike_new = loglikelihood(rolls, theta_A, theta_B)
        # 检查收敛
        if abs(loglike_new - loglike_old) < tol:
            break
        else:
            loglike_old = loglike_new
    return theta_A, theta_B

def e_step(rolls, theta_A, theta_B):
    heads_A, tails_A, heads_B, tails_B = 0, 0, 0, 0
    for trial in rolls:
        likelihood_A = likelihood(trial, theta_A)
        likelihood_B = likelihood(trial, theta_B)
        p_A = likelihood_A / (likelihood_A + likelihood_B)
        p_B = 1 - p_A
        heads_A += p_A * trial.count("H")
        tails_A += p_A * trial.count("T")
        heads_B += p_B * trial.count("H")
        tails_B += p_B * trial.count("T")
    return heads_A, tails_A, heads_B, tails_B

def m_step(heads_A, tails_A, heads_B, tails_B):
    theta_A = heads_A / (heads_A + tails_A)
    theta_B = heads_B / (heads_B + tails_B)
    return theta_A, theta_B

def likelihood(roll, theta):
    numHeads = roll.count("H")
    flips = len(roll)
    return (theta**numHeads) * ((1-theta)**(flips-numHeads))

def loglikelihood(rolls, theta_A, theta_B):
    total = 0
    for roll in rolls:
        heads = roll.count("H")
        tails = roll.count("T")
        total += math.log(0.5 * likelihood(roll, theta_A) + 0.5 * likelihood(roll, theta_B))
    return total

# 测试
rolls = ["HTTTHHTHTH", "HHHHTHHHHH", "HTHHHHHTHH", "HTHTTTHHTT", "THHHTHHHTH"]
print(coin_em(rolls))

  

输出:

(0.7967659656145668, 0.5195829299707858)

 

和原始答案比较接近。

 

 

 



标签:EM,机器,算法,tails,heads,theta,likelihood
From: https://blog.51cto.com/u_11908275/8416153

相关文章

  • 机器学习算法原理实现——朴素贝叶斯
    【先说条件概率】条件概率是指在某个事件发生的条件下,另一个事件发生的概率。以下是一个实际的例子:假设你有一副扑克牌(不包括大小王,共52张牌),你随机抽一张牌。我们设事件A为"抽到的牌是红色的"(红心和方块为红色,共26张),事件B为"抽到的牌是心"(红心共13张)。1.首先,我们可以计算事件A和事......
  • 机器学习算法原理实现——最大熵模型
    【写在前面】在sklearn库中,没有直接称为"最大熵模型"的类,但是有一个与之非常相似的模型,那就是LogisticRegression。逻辑回归模型可以被视为最大熵模型的一个特例,当问题是二分类问题,且特征函数是输入和输出的线性函数时,最大熵模型就等价于逻辑回归模型。【最大熵模型的原理】最大熵......
  • systemctl mask firewalld
    systemctlmaskfirewalldsystemctl--helpmaskNAME...MaskoneormoreunitsunmaskNAME...Unmaskoneormoreunits[root@hecs-98663~]#systemctlstatusfirewalld●firewalld.service-firewalld-dynamicfirewall......
  • (倒推2)E:\mmdetection-main\demo\image_demo.py 代码解读
    #Copyright(c)OpenMMLab.Allrightsreserved."""ImageDemo.Thisscriptadoptsanewinfenenceclass,currentlysupportsimagepath,np.arrayandfolderinputformats,andwillsupportvideoandwebcaminthefuture.Example:Save......
  • 机器学习——注意力汇聚:Nadaraya-Watson 核回归
    上节介绍了框架下的注意力机制的主要成分 图10.1.3:查询(自主提示)和键(非自主提示)之间的交互形成了注意力汇聚;注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。本节将介绍注意力汇聚的更多细节,以便从宏观上了解注意力机制在实践中的运作方式。具体来说,1964年提出的Nadara......
  • 算法刷题记录-哈希表
    算法刷题记录-哈希表有效的字母异位词给定两个字符串*s*和*t*,编写一个函数来判断*t*是否是*s*的字母异位词。注意:若*s*和*t*中每个字符出现的次数都相同,则称*s*和*t*互为字母异位词。示例1:输入:s="anagram",t="nagaram"输出:true示例2:输入:s......
  • 获取所有指定类名的元素:getElementsByClassName 注意是带s的
    下列不属于javascript中查找元素的方法的是()AgetElementByClassName()BgetElementsByTagName()CgetElementById()DgetElementsByName()正确答案:A选择A错在Elements。因为这个方法可以返回一组节点。A.获取所有指定类名的元素:getElementsByClassNamevarx=documen......
  • element menu结构 解释
    在使用element-uiMenu菜单的时候,一开始看很蒙蔽主要是因为这个组件里面有的东西有点多:而且还是嵌套嵌套这样的.整的就很难受.然后我就开始倒腾,一个一个拆解.最后得出结论标签需要放在最外层这个放在这个里层的任何位置,表示子菜单,然后和是配套的,下面解释......
  • elementui el-upload实现不自动上传,将上传内容放在formData里面,传递给后端
    //这种情况一般是要弹出一个弹框进行上传操作<el-uploadref="upload"action=""name="fileList":show-file-list="false":auto-upload=&qu......
  • 机器学习-小样本情况下如何机器学习
    交叉验证是在机器学习建立模型和验证模型参数时常用的办法。交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏。在此基础上可以得到多组不同的训练集和测试集,某次训练集中的某样本在下次可......