HMM算法学习——自己动手实现一个简单的HMM分词(java)

现在到了HMM,离CRF又近了一步。
主要书籍:《统计学习方法》
文章相关代码:github::xinlp

一些看似无关紧要的东西

这一部分其实很多是《统计学习方法》第十二章的东西,因为写这篇的时候其实我已经看完了,所以觉得放在这去讲比较好。

判别还是生成

其实这个问题原来我也没太注意,但是最近接触的有点多,比如生成对抗网络、CRF是判别式的,HMM是生成式的,所以还是要分一下。
以我的理解:

判别模型是直接对P(Y|X)建模,就是说,直接根据X特征来对Y建模训练,X-Y的映射的关系最后留下来,所以一有X就去映射Y。
生成模型是训练阶段是对P(X,Y)建模,得到这个分布的各种参数,可能包括X自己的分布,Y自己的分布还有各种条件分布,最后有X进来想求Y其实做了两件事——P(Y|X)=P(X,Y)/P(X),求出联合分布和X的分布,得到X条件下的Y。

两者对比:判别式模型好训练,但是就只能做一件事情,生成模型不好训练,但是仿佛达到了全知境界,不止能干预测,还能算概率什么的。

对应的模型

有些模型我也不太熟,照本宣科一下
判别式:大部分的神经网络模型、SVM、CRF、决策树
生成式:生成对抗网络、HMM、朴素贝叶斯法

概率模型还是非概率模型

概率模型一般表示是P(Y|X),非概率模型一般是Y=f(X)
HMM和NB就是概率模型,SVM、K-means、AdaBoost是非概率
决策树、CRF即可以看作概率,也可以看成非概率

图模型一览

HMM的定义

HMM是一个主要用于标注问题的生成式模型,是关于序列的概率模型。
隐马尔可夫模型是关于时序的概率模型,描述由一个隐藏的马尔可夫链随机生成不可观测的状态随机序列,再由各个状态生成一个观测而产生观测随机序列的过程。
隐藏的马尔可夫链随机生成的状态的序列,称为状态序列(state sequence):每个状态生成一个观测,而由此产生的观测的随机序列,称为观测序列(observation sequence)。序列的每一个位置又可以看作是一个时刻。

符号定义

符号定义包括参数的定义和变量的定义
变量符号定义:
观测集合:V,大小为M
观测序列:O,长度为T
隐藏集合:Q,大小为N
隐藏序列:I,长度为T
Hmm参数有三个:
初始状态概率分布Pi:矩阵维度是1 * N,就是隐藏集合的N个隐藏状态的初始分布
状态转移概率分布A:矩阵维度是N * N,N种隐藏状态向N种隐藏状态的转移概率分布
观测概率分布B:矩阵维度是N * M,M个观测状态分别在N个隐藏状态的概率分布
λ=(Pi,A,B),HMM三要素
HMM=(Q,I,Pi,A,B),HMM五元组

模型基本假设

1.齐次马尔可夫性假设
任意时刻t的状态只依赖于其前一时刻的状态,与其他时刻的状态和观测无关
2.观测独立性假设
任意时刻的观测只依赖于该时刻的马尔可夫链的状态。

这两个假设为什么要提呢?
因为这两个假设实际上已经把马尔可夫要做的事情说清楚了。
而且用图表示就是经典的下面这张小图:

白色的圈就是隐藏状态,黑色的圈就是观测状态
白色的圈之间横着的箭头就是假设1的结果;
白色到黑色的竖着的箭头就是假设2的结果;
这两个假设==这个图,所以看到这张图不是看个热闹,而是要记起来HMM的两个假设。

HMM的三个基本问题

1.概率计算问题;已知λ=(Pi,A,B)和观测序列O,计算P(O|λ),思想简单
2.参数学习问题;已知O,估计λ=(Pi,A,B),极大似然估计P(O|λ)中的λ,EM算法实现
3.预测解码问题;已知λ=(Pi,A,B)和观测序列O,计算P(I|λ),序列标注问题。

HMM的三个基本问题的解决方案(简述)

问题一:直接计算法(不可能),前向-后向算法(基本动态规划思想)
问题二:监督统计,非监督的Baum-Welch算法,EM算法的HMM实操版
问题三:近似算法,维特比算法(动态规划思想的结晶)

问题一

直接计算法

列举所有可能的状态序列,然后计算每种状态序列情况下的观测序列概率,最后求和。
时间复杂度非常高,对于每一个t,都有N种隐藏状态,那整个序列T的所有可能就是
N的T次方,然后求和又是T的所有复杂度,所以整体的时间复杂度是O(TN T )

前向算法和后向算法(非常非常重要)

先说一句很多博客的误区,就是后向算法他们以为没意义,在问题二的时候才有用,我开始也是这样觉得。
其实实际上当计算整个序列的话,因为最后一个时刻在后向算法里各种状态的概率都是1,所以对于整个观测序列的前向概率算完乘1不改变结果所以没有影响。
前向概率用alpha,后向概率用beta
前向计算初始是初始概率分布Pi,后向计算初始都是1
这个具体想看懂,看书就好,这里放两张图,其实两个差不多,都是多对一,后向从t+1到t,两个其实是“逆等价”的(随便造的新词。。。):
前向:

后向:

上面说的两则计算思想等效,那前向后向的时间复杂度就是一样的,O((T-1)N 2 )

(敲黑板)前向后向能干什么?

先说一句:参数学习算法又叫前向后向算法
要是计算整个序列的概率,前向就解决了
要是计算整个序列某个子序列出现的概率,那就必须要两者一起来算了
但是要是就去计算序列的概率,就不会说那么久了。
再来看这两个的公式:


两者的乘积是非常有意义的:

是不是有种恍然大悟的感觉?
你现在都可以算出来某个t时刻的某个i的概率,这个一般用γ(gamma)表示;

除此之外,你还可以求另外一个东西,如果α和β不都是i了,β算的是下一个时刻的概率,然后再有转移概率和观测概率,那我们可以求出来这个时刻是i1,下一个时刻是i2的概率,这个一般用ξ表示


下面这张图的所有线就都在上面公式的分子了。

上面的两个东西,γ和ξ又可以算出来很多东西:
(1)观测序列O中状态i出现的概率
选定i,可以把每一个时刻t出现i的概率求和,就是
(2)观测序列O中状态i转移的期望值
这句话换个说话就好理解了,状态i转移:其实就是不管现在是啥状态反正前一个状态是i的概率。那就是选定i,可以把1到T-1的都加起来就对了,
(3)观测序列O中由状态i转到状态j的期望值
选定i和j,把对应的t时刻的ξ加起来就是

问题二

监督学习算法——频数/总数就是概率

首先监督学习算法,就是数据足够,然后人工标注好,其实你只需要统计出来各种频数就可以了。
比如分词的时候:
统计B到E的频数,B到M的频数什么的都能求出来,转移矩阵A有了
每个字分别为BEMS的频数,观测矩阵B有了
样本中第一个字为B和S分别的概率,初始概率PI有了
然后就解决问题了,其实分词就是这样的,在人工标注好的数据集上统计就好了

非监督学习算法——Baum-Welch算法

第二种就是没有标注的情况下,你只有一堆句子(假设句子长短一致,都是T个字),这时候学习这些参数,就要用EM算法对于HMM参数学习问题的适配版——Baum-Welch算法,具体训练迭代就不说了,主要在于每一次E步和M步参数如何估计得到。

E步

再看一眼EM标准的Q函数

按照EM的套路,先要确定Q函数,参数θ这里是λ,观测变量Y这里是O,隐变量Z这里是I
所以HMM问题的Q函数可以写成:

但是后半部分是不好直接表示的,于是先做一步化简:

对于分母来说,O是固定的,λ杠也是固定的,对于变量λ来说,这是一个常数。
所以Q函数就变成了李航老师那本书的10.33公式写的样子(开始还以为书上写错了。。。。尴尬)

根据状态序列和观测序列的联合分布

所以Q函数可以拆解开来,变成下面这个样子:

M步

M步就是像前面求GMM的时候写的那样,求偏导,不过这次三个参数都需要利用拉格朗日乘子法。

PI

对于pi来说,其实只和第一个隐藏状态i1有关系,所以可以改写成

借用约束条件可以利用拉格朗日乘子法,写出拉格朗日函数:

令偏导=0:

得:

两边同时对i求和

最后得到新的pi是:

A

Q函数的第二项可以写成

当i不变的时候,j从1到N,此时从i到j的转移概率之和也是1,约束条件有了,这个和pi求解的过程一模一样

B

Q函数的第三项可以写成

此时的约束条件是:
对于一个时刻j,j的各种隐藏状态的概率之和等于1,和pi求解的过程也是一模一样

用γ和ξ表示新参数

上面说过γ和ξ都具有各种实际意义的,于是根据意义等价表示原则(自己瞎创的新词),可以用γ和ξ表示A,B,PI,如下图所示

两种角度看Baum-Welch算法

我觉得有两个角度可以去思考这个算法的意义。
一个是正向角度,就是像上面一样用EM的思想去先得到Q函数,然后M步,求各个变量求偏导,得到各个新参数用原来的旧参数表示,然后去解决旧参数的计算问题(用γ和ξ表示参数)。
另一个是反向看(当然有点事后诸葛亮的味道),其实你知道了γ和ξ以后,就不用Q函数,直接根据意义表示A,B,PI。于是参数学习直接用下面这张图去解释,你先找到参数的意义,根据意义等价表示原则,直接用γ和ξ表示A,B,PI。

问题三

近似算法

近似算法就是,找出每个时刻隐藏状态集合中出现概率最大的隐藏状态。
这个不好,不是因为时间复杂度太高,其实一点都不高,是因为有问题,比如分词的时候,比如"保证"这两个字,假如算出来的都是隐藏状态为‘B’的时候概率最大,但是B后面只能是M或者E呀,这时候就会有问题。
但是近似算法在很多情况下还是有用的。

维特比算法

viterbi算法是一种动态规划的思想,不仅HMM要用,CRF一样的思想也要用 ,但是又不像前向算法就是为了降低时间复杂度先存着前一步生成的结果以便于后面去用那么直接。
这个思想的基础是这样的,如果从p1到pN存在一条最好的路是k,如果这条路k经过了p‘,则p’到pN一定是最优的,如果不是最优的,就存在另外一条路k‘,他的p’到pN更好,那这条路就不是k了,矛盾。所以我们只要找到最大概率的最后一个结点,然后一步一步向前就能求出来最优路径。
再说的接近算法实现一点,就是每次我们都根据t时刻的隐藏状态情况算出t+1时刻各种隐藏状态情况,记录t+1时刻概率最大的时候t时刻是哪种隐藏状态就行了。
一般需要两个矩阵:
一个是δ,矩阵维度是T * N,每一个矩阵的值(δtn)用来记录t时刻状态为n的概率最大值(因为都是前一个状态概率 * 转移概率,不一定是哪个大)
一个是ψ,矩阵维度是T * N,记录t时刻概率最大值是由t-1时刻哪一个状态乘以转移概率得到的。
具体就是求出来δ和ψ以后,路径回溯就得到了隐藏状态。

各种荔枝的实战

网上关于HMM,基本就是三个小例子(盒子与球,根据这个人出不出门判断那几天天气怎么样,根据和女朋友的聊天记录判断女朋友心情),这三个例子都很简单;
最后实现HMM分词,借用jieba的参数去写viterbi算法实现。

说说盒子与球模型

类比EM算法的三硬币模型,就是掷哪个硬币的概率不是来自于一个原来统一的A硬币去解决了,而是由前一个硬币的结果来自于一个怎样的A硬币决定的。
所以理解了三硬币以后,只要把里面的那个唯一的A硬币考虑成每次都有这个A硬币去决定,而且这个A硬币还不是独立的,这一次抛的结果跟上一次抛的有很大关系。

这个问题的题干真的不想叙述了,查看问题去这个博客看一下就好。

这个例子并不是很难,只是理解了这个,什么都能套上去。

比如:
天气那个例子的盒子就是各种天气,一个人出去干了什么就是球,天气变化A就是盒子怎么抽球的规定,什么天气出去干什么的可能性就是B,本来一天的天气可能就是PI
女朋友心情的例子里的盒子就是她高兴还是生气,看到的球就是她对你说的话,A就是她本来性情转变如何,B就是每说一句话思考有没有潜台词(比如“我服你了”,可能是正面高兴,也可能是“你是傻逼吗”生气,所以也是概率),PI就是这个女孩生性乐观还是易怒。
分词的盒子就是BEMS,每个汉字就是一种颜色的球,每种颜色的球来自哪个盒子就是B,BEMS互相之间的转移概率A就是规定的盒子抽球的规则,PI就是什么都不知道的情况某个字在BEMS四种情况的概率。

HMM分词实现——viterbi算法

用的jieba分词的参数,心里话,效果不错
有时间自己用EM算法训练一下参数。
刚才已经类比盒子与球模型里面出现的各种变量与参数了,所以直接说Java实现吧,写这个代码的时候自己也是摸石头过河,所以注释还是比较多的。具体代码和相关文件可以访问: https://github.com/1000-7/xinlp

package segment.hmm;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Test;
import segment.crf.XinCRFConfig;

import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

/**
 * viterbi算法
 * 已知状态转移矩阵A、概率观测矩阵B、初始状态概率向量Pi和观测序列O
 * 求可能性最大的状态序列
 * 在分词问题上,状态集合是BEMS
 * O是"武汉大学真美"
 * I是"BMMESS"
 * 下面是Jieba分词的参数
 * Pi是
 * {'B': -0.26268660809250016,
 * 'E': -3.14e+100,
 * 'M': -3.14e+100,
 * 'S': -1.4652633398537678}
 * A是
 * {'B': {'E': -0.510825623765990, 'M': -0.916290731874155},
 * 'E': {'B': -0.5897149736854513, 'S': -0.8085250474669937},
 * 'M': {'E': -0.33344856811948514, 'M': -1.2603623820268226},
 * 'S': {'B': -0.7211965654669841, 'S': -0.6658631448798212}}
 * B是
 * {'B': {'\u4e00': -3.6544978750449433,
 * '\u4e01': -8.125041941842026,
 * '\u4e03': -7.817392401429855,
 * '\u4e07': -6.3096425804013165,
 * '\u4e08': -8.866689067453933,
 * '\u4e09': -5.932085850549891,
 * '\u4e0a': -5.739552583325728,
 * '\u4e0b': -5.997089097239644,
 * '\u4e0d': -4.274262055936421,
 * '\u4e0e': -8.355569307500769,
 * ...},
 * 'E': {'\u4e00': -6.044987536255073,
 * '\u4e01': -9.075800412310807,
 * '\u4e03': -9.198842005220659,
 * '\u4e07': -7.655326112989935,
 * '\u4e08': -9.02382100266782,
 * '\u4e09': -7.978829805438807,
 * '\u4e0a': -5.323135439997585,
 * '\u4e0b': -5.739644714409899,
 * ...},
 * 'M': {...},
 * 'S': {...}
 * }
 *
 * @author unclewang
 */
@Data
@Slf4j
public class ViterbiHmm {

    private static char[] state = new char[]{'B', 'E', 'M', 'S'};
    /**
     * 状态值集合的大小 N
     **/
    protected static int stateNum = state.length;

    protected static final Double MIN = -3.14e+100;
    /**
     * 初始状态概率Pi
     **/
    protected Double[] pi;
    /**
     * 转移概率A
     **/
    protected Double[][] transferProbability;
    /**
     * 发射概率B
     **/
    protected Double[][] emissionProbability;

    /**
     * 观测值集合的大小 出现了几种可能性,红白球的话就是2,分词的话机会是词表的长度
     **/
    protected int observationNum;

    /**
     * 观测序列O,比如 武汉大学真美
     */
    private Integer[] observeSequence;
    /**
     * 词典和id双向对应map
     */
    private BiMap<String, Integer> wordId;

    /**
     * 使用jieba分词使用的概率
     */
    public void initLambda() {
        initPi();
        initA();
        initB();
    }


    /**
     * 维特比算法
     */
    public void viterbi(String s) {
        initLambda();
        String[] sentences = s.split("[,.?;。,]");
        for (String sentence : sentences) {
            viterbi(str2int(sentence));
        }
    }


    public void viterbi(Integer[] observeSequence) {
        observationNum = observeSequence.length;
        Integer[][] path = new Integer[observationNum][stateNum];
        Double[][] deltas = new Double[observationNum][stateNum];

        for (int i = 0; i < stateNum; i++) {
            deltas[0][i] = pi[i] + emissionProbability[i][observeSequence[0]];
            path[0][i] = i;
        }

        for (int t = 1; t < observationNum; t++) {
            for (int i = 0; i < stateNum; i++) {
                deltas[t][i] = deltas[t - 1][0] + transferProbability[0][i];
                path[t][i] = 0;
                for (int j = 1; j < stateNum; j++) {
                    double tmp = deltas[t - 1][j] + transferProbability[j][i];
                    if (tmp > deltas[t][i]) {
                        deltas[t][i] = tmp;
                        path[t][i] = j;
                    }
                }
                deltas[t][i] += emissionProbability[i][observeSequence[t]];
            }
        }
        XinCRFConfig.print(deltas);
        XinCRFConfig.print(path);

        //找最优路径,注意最后一个字不是所有状态的最大值,而是E(1)和S(3)的最大值
        Integer[] mostLikelyStateSequence = new Integer[observationNum];
        mostLikelyStateSequence[observationNum - 1] = deltas[observationNum - 1][1] >= deltas[observationNum - 1][3] ? 1 : 3;

        for (int i = mostLikelyStateSequence.length - 2; i >= 0; i--) {
            mostLikelyStateSequence[i] = path[i + 1][mostLikelyStateSequence[i + 1]];
        }
        for (int i = 0; i < observationNum; i++) {
            System.out.print(wordId.inverse().get(observeSequence[i]));
            if (mostLikelyStateSequence[i] == 1 || mostLikelyStateSequence[i] == 3) {
                System.out.print(" ");
            }
        }

    }

    public Integer[] str2int(String s) {
        char[] chars = s.toCharArray();
        Integer[] res = new Integer[chars.length];
        for (int i = 0; i < chars.length; i++) {
            res[i] = wordId.getOrDefault(String.valueOf(chars[i]), 1);
        }
        return res;
    }


    private void initB() {
        try {
            String list = FileUtils.readFileToString(new File(System.getProperty("user.dir") + "/src/main/resources/B.json"), "UTF8");
            JSONObject jsonObject = JSON.parseObject(list);
            Map<String, Double> bMap = toDouble(JSON.parseObject(jsonObject.get("B").toString()).getInnerMap());
            Map<String, Double> eMap = toDouble(JSON.parseObject(jsonObject.get("E").toString()).getInnerMap());
            Map<String, Double> mMap = toDouble(JSON.parseObject(jsonObject.get("M").toString()).getInnerMap());
            Map<String, Double> sMap = toDouble(JSON.parseObject(jsonObject.get("S").toString()).getInnerMap());
            HashSet<String> wordSet = new HashSet<>(bMap.keySet());
            wordSet.addAll(eMap.keySet());
            wordSet.addAll(mMap.keySet());
            wordSet.addAll(sMap.keySet());
            emissionProbability = new Double[stateNum][wordSet.size()];
            wordId = HashBiMap.create();
            int i = 0;
            for (String s : wordSet) {
                wordId.put(s, i);
                emissionProbability[0][i] = bMap.getOrDefault(s, MIN);
                emissionProbability[1][i] = eMap.getOrDefault(s, MIN);
                emissionProbability[2][i] = mMap.getOrDefault(s, MIN);
                emissionProbability[3][i] = sMap.getOrDefault(s, MIN);
                i++;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static Map<String, Double> toDouble(Map<String, Object> map) {
        Map<String, Double> res = new HashMap<>();
        map.forEach((key, value) -> res.put(key, ((BigDecimal) value).doubleValue()));
        return res;
    }


    private void initA() {
        transferProbability = new Double[][]{
                {MIN, -0.510825623765990, -0.916290731874155, MIN},
                {-0.5897149736854513, MIN, MIN, -0.8085250474669937},
                {MIN, -0.33344856811948514, -1.2603623820268226, MIN},
                {-0.7211965654669841, MIN, MIN, -0.6658631448798212}};
    }

    private void initPi() {
        pi = new Double[]{-0.26268660809250016, -3.14e+100, -3.14e+100, -1.4652633398537678};
    }

    public <T extends Object> void print(T[] nums) {
        for (T a : nums) {
            System.out.print(a + "\t");
        }
        System.out.println();
    }

    public <T extends Object> void print(T[][] nums) {
        for (int i = 0; i < nums.length; i++) {
            for (int j = 0; j < nums[0].length; j++) {
                System.out.print(nums[i][j] + "\t");
            }
            System.out.println();
        }
        System.out.println();
    }

    @Test
    public void segmentTest() {
        initLambda();
        viterbi("今天的天气很好,出来散心挺不错,武汉大学特别好,提高人民的生活水平");
    }
}

最后效果截个图:

“HMM算法学习——自己动手实现一个简单的HMM分词(java)”的2个回复

  1. hi博主~
    感谢教程~
    想请问一下,有没有什么方法能处理离散的观测序列呀?比如观测序列是一个词向量。
    谢谢~

发表评论

电子邮件地址不会被公开。