GMM高斯混合模型

文章相关代码参见github:xinlp

EM算法的一个重要应用是高斯混合模型的参数估计。
高斯混合模型(Gaussian Mixed Model)指的是多个高斯分布函数的线性组合,理论上GMM可以拟合出任意类型的分布,常用于解决同一集合下的数据包含多个不同的分布的情况(或者是同一类分布但参数不一样,或者是不同类型的分布,比如正态分布和伯努利分布)。

单高斯模型——GSM

大家都知道正态分布,这一分布反映了自然界普遍存在的有关变量的一种统计规律,例如身高,考试成绩等;而且有很好的数学性质,具有各阶导数,变量频数分布由θ完全决定等等,在许多领域得到广泛应用。在这里简单介绍下高斯分布的概率密度分布函数,一个高斯分布,θ=(μ,σ2):

高斯混合模型——GMM

栗子--商城里的人

在百度上搜GMM的博客,经常有这样的例子:
商场里有男人和女人,他们的身高应该分别是两个正态分布,那么所有人的身高是一个混合的高斯分布,可以用两个高斯分布去拟合出来。(其实我模拟了数据去算,很难收敛成两个不同正态分布,基本拟合成两个一样的正态分布,这两真的太像了)。

GMM定义

高斯混合模型就是对于k个正态分布,每个正态分布有αk的概率被选取,然后每个正态分布是一个θk,形式如下:

上面的商场的例子在这个模型里就是K=2(男,女两种正态分布)
需要估计的参数就是α(男),α(女),μ(男),μ(女),σ(男),σ(女)6个参数。
其中α(男)+α(女)=1,这里的人不是男就是女。
但实际情况比这个可能要复杂一些,比如K不一定等于2,需要事先设定,所以一般有3K+1个参数需要估计出来,但是K一般事先初始化就不会变了。

EM算法看GMM

这个光看《统计学习方法》的时候,有时候会觉得有点跳,最后结合PRML两本书一起看比较清晰。

一、明确参数、隐变量和数据

首先要引用γjk代表观测数据yj来自第k个模型,假如yj来自第k个分模型,γjk=1,其余分模型的γjk=0,此时是一个“1-of-K“表示的方法,就是k从1到K,只有一个是1,其他都等于0

j属于(1,N),k属于(1,K)
1.需要估计的参数:包括K,α1到αK,θ1到θK 总共3K+1个参数
2.观测数据:y1到yN 总共N个数据
3.未观测数据:γ11到γNK 总共NK个数据,其中有N个1,N(K-1)个0,这时候可以理解先验没有去似然,所以还是数据,其实后面也需要把这些隐变量当作参数进行似然估计
4.完全数据:观测数据和未观测数据 总共(N+1)K个数据

二、写出完全数据的对数似然函数

先不求对数,先写出完全数据的似然函数,像下面这样

上面有四行:
第一行就是正常写法,就是参数作为条件,算出某一个观测数据的概率,再把所有数据的概率相乘;
第二行就是θ拆出来αk和θk,然后变成概率密度分布函数的形式去写,因为这是1-of-K的形式 ,所以对于每个γjk,可以写成幂的形式。PRML那本书是这样说的:


虽然
第三行是令,然后可以得到。其实这里放个大括号更好。
第四行是把高斯分布写成具体的形式。

最后去对数以后的结果:

三、E步:Q函数

Q函数:完全数据的对数似然函数,在给定观测数据Y和当前参数θi对未观测数据γ的条件概率分布P(γ|Y,θi)的期望

上面图的Eγjk其实就是E(γjk|Y,θi)
所以你必须必须要知道一点,上面后面那么一长串子东西,到底要求什么的期望,这个问题想算的是γ的期望,就是要知道γjk=1的时候的概率。因为一个东西不是1就是0,二项分布,所以E(γ|Y,θ)=P(γ=1|Y,θ),然后怎么算,就看下面这张图,顺利成章了。


上面这张图的结果就是是在当前模型参数下第j个观测数据来自第k个分模型的概率,称为分模型k对观测数据yj的响应度。PRML这本书说这个是分模型k对数据yj的责任。
将nk代入以及γjk的期望代入后:

四.M步:求偏导,得新θ

M步是对Q函数对θ的极大化

求偏导(这个过程《统计学习方法》没有细讲,而且这个地方好多都写错了,所以稍微引用点PRML那本书的内容):

先说几个事情,

  1. y和γ所有变量都是已知的(γ已经通过E步得到了从γ11到γNK所有的相应度)
  2. 求偏导的顺序应该是μ->σ的平方,α先后都可以
  3. 求α比较麻烦,需要运用拉格朗日乘子法(原来的博客讲SVM的KKT的时候讲过)

先求μ,因为它和其他参数都没有关系,它就是一个连加的二次函数,比如a1(b1-x)2 +a2(b2-x)2=0 ,直接放里面求x=a1b1+a2b2/a1+a2,分子分母都要分别求和,不能直接约了.
再求σ,相当于求a1[logx+(b1-c1)2 *(1/2x2 )]+a2[logx+(b2-c2)2 *(1/2x2 )]=0,注意这里的c就是刚才求的μ
后求α,需要借助一下α的等式条件。

自己又手推写了个详细版:

最后求出来的结果都是:

这个时候,整个推导过程就结束了,重复迭代,参数不再变化或者达到迭代次数就行了。

算法思路总结

  1. 取初始值,初始值敏感
  2. E步求出分模型k对观测数据yj的响应度
  3. M步重新得到新的估计参数
  4. 重复E和M直到收敛

思路清晰,两个字通透

算法——java实现GMM

这个就是本来想模拟男女和小孩的身高,最后发现根本就区分不出来,最后无奈只能把模拟男女身高数据的平均值设的很夸张。。。。下面男的身高mean都是负的了。。。。不过最后拟合效果很好

还想吐槽一个东西。初始值敏感这事,本以为就是有点敏感,其实是非常敏感,所有初始化参数值时先验知识很重要。

    mean = new Double[]{-170.0, 1600.3, 103.5};
    sd = new Double[]{10.0, 10.2, 23.5};
    k = new Double[]{0.3, 0.3, 0.4};
import lombok.extern.java.Log;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * @Author unclewang
 * @Date 2018/11/15 14:15
 */
@Slf4j
public class GmmTest {
    //生成数据
    private final static int N = 100;
    private static Double[] men = new Double[(int) (N * 3.5)];
    private static Double[] women = new Double[N * 4];
    private static Double[] children = new Double[(int) (N * 2.5)];
    private static Double[] people = new Double[men.length + women.length + children.length];
    private static Double[] m = new Double[people.length];
    private static Double[] w = new Double[people.length];
    private static Double[] c = new Double[people.length];
    private static Double[] e = new Double[people.length];
    private RandomGenerator rg = new MersenneTwister(100);

    /**
     * EM算法参数定义
     */
    private Double[] mean;
    private Double[] sd;
    //k的和应该等于1,k[0]对应men,k[1]对应women,k[2]对应children
    private Double[] k;

    @Test
    public void test() {
        init();
        for (int i = 0; i < 400; i++) {
            e();
            m();
        }
    }

    @Test
    public void m() {
        mean = new Double[]{reCountMean(m), reCountMean(w), reCountMean(c)};
        sd = new Double[]{reCountSd(m), reCountSd(w), reCountSd(c)};
        k = new Double[]{reCountK(m), reCountK(w), reCountK(c)};
        System.err.println(k[0] + "\t" + k[1] + "\t" + k[2]);
    }

    @Test
    public void e() {
        for (int i = 0; i < e.length; i++) {
            m[i] = k[0] * getP(people[i], mean[0], sd[0]);
            w[i] = k[1] * getP(people[i], mean[1], sd[1]);
            c[i] = k[2] * getP(people[i], mean[2], sd[2]);
            e[i] = m[i] + w[i] + c[i];
            m[i] /= e[i];
            w[i] /= e[i];
            c[i] /= e[i];
        }
        log.info("迭代结果:" + reCountMean(m) + "\t" + reCountMean(w) + "\t" + reCountMean(c));
    }


    public double reCountMean(Double[] d) {
        double sum = 0;
        double meanSum = 0;
        for (int i = 0; i < e.length; i++) {
            meanSum += d[i];
            sum += d[i] * people[i];
        }
        return sum / meanSum;
    }

    public double reCountSd(Double[] d) {
        double newMean = reCountMean(d);
        double sdSum = 0;
        double meanSum = 0;
        for (int i = 0; i < e.length; i++) {
            sdSum += d[i] * FastMath.pow(people[i] - newMean, 2);
            meanSum += d[i];
        }
        return sdSum / meanSum;
    }

    public double reCountK(Double[] d) {
        double meanSum = 0;
        for (int i = 0; i < d.length; i++) {
            meanSum += d[i];
        }
        return meanSum / people.length;
    }

    @Test
    public void testLength() {
        System.out.println(men.length);
        System.out.println(women.length);
        System.out.println(children.length);
        System.out.println(men.length + women.length + children.length);
        System.out.println(people.length);
        System.out.println(e.length);
    }

    @Test
    public void init() {
        generate();
        //初始化参数,因为猜测来自三种人的分布,所以数组的长度都是3
        mean = new Double[]{-170.0, 1600.3, 103.5};
        sd = new Double[]{10.0, 10.2, 23.5};
        k = new Double[]{0.3, 0.3, 0.4};
        log.info("(1)正态分布的均值初始值设定:" + mean[0] + "\t" + mean[1] + "\t" + mean[2]);
    }


    @Test
    public void testGetP() {
        System.out.println(getP(0, 0, 1));
        System.out.println(getP(3, 0, 1));
        System.out.println(getP(-3, 0, 1));
        System.out.println(getP(1, 162, 13));
    }


    public double getP(double x, double mean, double sd) {
        NormalDistribution nd = new NormalDistribution(mean, sd);
        double p = Math.abs(nd.cumulativeProbability(x));
        return p > 0.5 ? 1 - p : p;
    }

    public Double[] generatePeople(Double[] people, double mean, double sd) {
        for (int i = 0; i < people.length; i++) {
            people[i] = normal(mean, sd);
        }
        return people;
    }

    public Double[] generate() {
        log.info("正在生成1000个人的数据");
        men = generatePeople(men, -178, 5);
        women = generatePeople(women, 1630, 5);
        children = generatePeople(children, 100, 4);
        log.info("数据分布情况介绍:\n" + "平均值\t178\t163\t100\n标准差\t10\t15\t24");
        for (int i = 0; i < people.length; i++) {
            if (i < men.length) {
                people[i] = men[i];
            } else if (i < men.length + women.length) {
                people[i] = women[i - men.length];
            } else {
                people[i] = children[i - men.length - women.length];
            }
        }
        List<Double> peopleList = Arrays.asList(people);
        Collections.shuffle(peopleList, new Random(10));
        people = peopleList.toArray(new Double[]{});

//        print(people);
        return people;
    }

    public double normal(double mean, double sd) {
        NormalDistribution nd = new NormalDistribution(rg, mean, sd);
        return nd.sample();
    }

    public <T extends Object> void print(T[] nums) {
        for (T a : nums) {
            System.out.print(a + "\t");
        }
        System.out.println();
    }
}

“GMM高斯混合模型”的一个回复

  1. 你好作者,里面的有个公式响应度是怎么推导的哪?尤其是第一步,我有点不明白,计算 ‘伽玛’ 的数学期望伽。

发表评论

电子邮件地址不会被公开。