EM算法学习

意图识别 <= Bi-LSTM+CRF <= 先懂CRF <= 先懂HMM <= 先懂EM
学习书籍:《概率论与数理统计》陈希儒.《统计学习方法》李航. 《PRML》神书.
此外参考了一些博客,看懂EM还是花了一些时间的,在李航那本书上做了很多笔记,所以整理出来这篇博客。

文章相关代码参见github:xinlp

这一篇就是从EM开始讲,EM的各种变量也统一是用《统计学习方法》这本书的,言归正传。

预备知识

Jensen不等式

看这个,传送门
凸函数和凹函数就不说了,就是二阶偏导恒大于等于0还是恒小于等于0.
y=x ^ 2,y'=2x,y''=2 ==> 凸函数
y=logx,y'=-1/x,y''=-1/x ^ 2 ==> 凹函数
大概就是凸的就是(先函数算出来x的值y,在对y进行求和或者积分,最后算出来均值或者期望)要比 (先对x进行求和或者积分算完的平均值,放进函数y求出来的结果)大。

凹函数就是不等式反过来。

极大似然估计

这个是参考概率论那本书,三个词,极大、似然和估计,一个词一个词的看

  • 极大,就是寻找函数在某个区间的极大值,一般是凹函数。
  • 似然,P(x|y),如果固定y,求x的分布那就是概率,这个函数也就是一个概率密度函数或者概率函数,比如P(x1|y)>P(x2|y),就是指在y的情况下,x1发生的可能性比x2要大;反过来想如果P(x|y1)>P(x|y2)(通常写做L(x;y))是什么呢?直面上就代表了y1情况下发生x的可能性比y2情况下发生x的可能性要大。这个L函数就是似然函数,反映了x固定的时候,对y取不同的值会产生不同的结果,即参数y是x发生的原因,但是y可能不是事件或者随机变量,所以不能说这个函数是概率,所以改用似然这个词。
  • 估计,当样本X分布是固定的,想知道什么样的参数Y,会产生这个结果X,就是一个参数估计的过程。

合起来说,极大似然估计就是已知样本X的情况下,这个"看起来最像"的参数值,这个估计(y1,y2,...,yk)就是参数Y的极大似然估计。
一个trick,通常计算不会直接计算L似然函数,而是计算log(L)即对数似然,因为X是(x1,x2,...,xn),X事件发生就是x1到xn同时发生,所以L(X;Y)=f(x1;y1,..yk)f(x2;y1,..yk)...f(xn;y1,..yk),这是一个连乘,利用对数函数可以更改为连加求和,因为求导数x各种加肯定比x各种乘要容易。

极大似然估计实例

借鉴这个博客提到的例子。
原理:假设在一个罐子中放着许多白球和黑球,并假定已经知道两种球的数目之比为1:3但是不知道那种颜色的球多。如果用放回抽样方法从罐中取5个球,观察结果为:黑、白、黑、黑、黑,估计取到黑球的概率为p;
假设p=1/4,则出现题目描述观察结果的概率为:(1/4)4 *(3/4) = 3/1024
假设p=3/4,则出现题目描述观察结果的概率为:(3/4)4 *(1/4) = 81/1024
由于81/1024 > 3/1024,因此任务p=3/4比1/4更能出现上述观察结果,所以p取3/4更为合理


(图片来自浙江大学概率论课程课件)

EM算法

定义

EM算法是一种迭代算法,在统计中被用于寻找,依赖于不可观察的隐性变量的概率模型中,参数的最大似然估计。在统计计算中,EM算法是在概率模型中寻找参数最大似然估计或者最大后验估计的算法,其中概率模型依赖于无法观测的隐性变量。EM算法经过两个步骤交替进行计算,第一步是计算期望(Exception),利用对隐藏变量的现有估计值,计算其最大似然估计值;第二步是极大化(Maximization),最大化在E步上求得的最大似然值来计算参数的值。M步上找到的参数估计值被用于下一个E步计算中,这个过程不断交替进行。

栗子——三硬币模型

问题描述

每个硬币都只有两种情况,就是正反,所以是一个二项分布(二项分布确实是离散变量最简单的一种分布,好多一讲分布就先拿二项分布举例子)。
假设:假设有三枚硬币,分别记为A、B、C。这些硬币正面的概率分别为π,p,q,进行如下的抛硬币实验:先掷硬币A,根据其结果选出硬币B或者硬币C,正面选硬币B,反面选硬币C,然后掷选出的硬币,掷硬币的结果出现正面记作1,出现反面记作0,独立地重复n次实验(这里n=10),然后观测结果如下:
1,1,0,1,0,0,1,0,1,1
假设只能观测到掷硬币的结果,不能观测掷硬币的过程,问如何估计三硬币正面出现的概率,即三硬币模型的参数π,p,q。

解题思路

用y来表示硬币掷出的结果,y可以等于1或者0,θ=(π,p,q);
当y=1,p(y|θ)=πp+(1-π)q;
当y=0,p(y|θ)=π(1-p)+(1-π)(1-q)
所以整体可以写为下面这样,也是参照二项式分布的标准写法:

这里随机变量y是观测变量,表示一次实验观测的结果是1或0,随机变量z是隐变量,表示未观测到的掷硬币A的结果,θ=(π,p,q)是模型参数,这一模型是以上数据的生成模型。再提醒一次,随机变量y的数据可以观测,随机变量z的数据不可观测。
将观测数据表示为Y,未观测数据表示为Z,则观测数据的似然函数是:

其实就是真实发生的流程,先求出来隐变量的概率分布,根据隐变量和模型参数算出来观测变量的的概率分布。

在该问题中,似然函数展开为

考虑求模型参数θ=(π,p,q)的极大似然估计,即:

这个问题没有解析解,只有通过迭代方法求解,EM算法就是可以用于求解这个问题的一个迭代算法,下面给出求解这个问题的EM算法过程。

解题步骤

EM算法首先选取参数的初值θ(0),然后通过下面的步骤迭代计算参数的估计址,直到收敛为止。EM算的第i+1次迭代过程如下:

上面这张图看了很久,这张图(9.5)的地方应该μ有个下标j,新版的书已经添上了。
E步的公式就是9.5,看懂9.5就好看懂剩下的了:

分母其实就是B硬币产生yj的概率+C硬币产生yj的概率=第j个观测是yj的概率(yj是观测序列的第j个值,可能为0或1)
分子为B产生yj的概率
所以μ两者相除就是产生yj来自B的概率

M步的参数就是更新θ的三个值,分别来看

9.6 更新π 因为π就是μ(yj来自B的概率),所以只需要对观测序列每一个值yj算出来自B的概率,求和之后再除以序列长度N就行了。
9.7 更新p 因为p是硬币B正的概率,只需要用来自硬币B的时候而且是正面的概率除以来自硬币B的概率就行,当yj=0,分子+0,yj=1,分子加上来自硬币B的概率,分母相当于一直乘以yj=1
9.8 更新q 和9.7类似

迭代计算(初值敏感)

如果θ0=(0.5,0.5,0.5),最后的θ=(0.5,0.6,0.6)
这是一个很合理的结果,因为观测序列出现了6次正面,又假设A是均匀的,所以肯定估计B、C出现正面肯定在6/10左右
如果θ0=(0.4,0.6,0.7),最后的θ=(0.4064,0.5368,0.6432)
三硬币模型就结束了

EM算法核心步骤——Q函数

从上面的三硬币模型看EM算法操作起来比较简单,但是看公式其实就比较复杂了。
一般的,用Y表示观测随机变量的数据,Z表示隐随机变量的数据,Y和Z连在一起称为完全数据,只有观测数据Y称为不完全数据,假设给定观测数据Y,其概率分布为P(Y|θ),那么不完全数据的似然函数就是P(Y|θ),对数似然函数是L(θ) = log(P(Y|θ)),假设Y和Z的联合概率分布是P(Y,Z|θ),那么完全数据的对数似然函数是logP(Y,Z|θ)。

EM算法通过迭代求L(θ) = log(P(Y|θ))的极大似然估计,每次迭代包括两步:E步,求期望,M步,求最大化,下面介绍EM算法的步骤:


这里面有一个Q函数,只要用EM算法就要找准这个Q函数,GMM和HMM都有自己的Q函数。
公式9.9和公式9.10两个合起来看可以这么说:

  1. Q函数想做的事情就是用θ(i)和Y求出来Z,然后极大似然估计Y,Z下的新的θ,即θ(i+1)。
  2. 公式9.9就是像求Z之后似然估计新θ
  3. 公式9.10就是得到极大化的新θ
    实际上Q函数对θ的参数分别求偏导=0,就可以得到用θ(i)表示的新的参数θ(i+1)
    ### EM算法具体导出
    面对一个含有隐变量的概率模型,目标是极大化观测数据Y(不完全数据)对于θ的对数似然函数,即极大化

但是我们不能直接的得到Z,所以上面这个函数是不能直接求出来的。
于是曲线救国,和神经网络的梯度下降理论类似,偏导=0求不出来,我慢慢的迭代逼近呀,只是不是利用反向传播而已,用的极大似然,希望每一次新的θ值比旧的θ值使Y发生的可能性变得更大,逐步逼近极大值。
下面证明的就是一次更比一次好:

不等号的原因是Jenson不等式,预备知识里提到过。
第三行能变成最后等号的原因是因为=1,所以被减数*1再合并。

然后可以得到L(θ)的下界B(θ),那我只要你的下限一直在极大化就行,于是下面求θ(i+1)极大化B(θ)可以证明就是在极大化Q(θ,θ(i)).


上面证明了EM算法为什么有那个Q函数?因为通过迭代引入下界B函数,极大B函数时就是极大Q函数,所以可以直接认为需要极大Q函数。但是并没有说明为什么极大Q函数最后就能得到那个估计出来的最好的值。
还需要证明两个东西:
一个是这个能收敛,第二个是稳定的时候能得到的就是极大值。
反正别人都证明了,书上给的是参考文献。

三硬币模型——java实现

import org.junit.jupiter.api.Test;

/**
 * @author unclewang
 */
public class EmTest {
    @Test
    public void test() {
        //每个硬币初始一次为正的概率
        double[] yita = m(0.2, 0.5);
        for (int i = 0; i < 100; i++) {
            yita = m(yita);
            System.out.println(yita[0] + "\t" + yita[1]);
        }
    }

    public double[] m(double... yita) {
        int[] nums = {5, 5, 9, 1, 8, 2, 4, 6, 7, 3};
        double[] e = new double[5];
        double[] m = new double[5];
        double[] m_ = new double[5];
        double[] n = new double[5];
        double[] n_ = new double[5];
        for (int i = 0; i < e.length; i++) {
            //e步
            e[i] = e(yita[0], nums[i * 2], yita[1]);
            m[i] = e[i] * nums[2 * i];
            m_[i] = e[i] * nums[2 * i + 1];
            n[i] = (1 - e[i]) * nums[2 * i];
            n_[i] = (1 - e[i]) * nums[2 * i + 1];
        }
        double yita1 = sum(m) / (sum(m) + sum(m_));
        double yita2 = sum(n) / (sum(n) + sum(n_));
        System.out.println("开始迭代");
        print(e);
        print(m);
        print(m_);
        print(n);
        print(n_);
        return new double[]{yita1, yita2};
    }

    public void print(double[] nums) {
        for (double a : nums) {
            System.out.print(a + "\t");
        }
        System.out.println();
    }

    public double sum(double[] nums) {
        double sum = 0;
        for (double a : nums) {
            sum += a;
        }
        return sum;
    }

    public double e(double a, double b, double c) {
        double e1 = Math.pow(a, b) * Math.pow(1 - a, 10 - b);
        double e2 = Math.pow(c, 10 - b) * Math.pow(1 - c, b);
        return e1 / (e1 + e2);
    }
}

发表评论

电子邮件地址不会被公开。