分享

高斯混合模型之代码实现

 陈永正的图书馆 2017-06-08

高斯混合模型的代码实现,总体的思路是比较简单的。但涉及到具体的优化,如多维高斯概率分布协方差矩阵的逆矩阵,就是一个很头疼的奇异矩阵问题。这里我只想讲下代码实现的流程。具体的代码可以参照:http://blog.csdn.net/crzy_sparrow/article/details/7413019。(注意他的代码没有考虑协方差逆矩阵的问题)。


高斯混合模型代码实现流程

(1)·首先是初始化,高斯混合模型的效果很大程度上依赖于初始点的设定。一般我们用K-means聚类生成K个中心节点。对于属于同一节点的数据,我们求其均值,方差以及该节点的概率。这里所谓的均值就是中心节点,协方差矩阵按照定义求解,该节点概率(选择该个高斯模型的概率)= 属于该节点的数据个数 / 总数据个数,这样初始化完成。

(2)·E-STEP:求得Q(j),这里要将上次得到的均值u,协方差sigma,模型概率pj,带入Q(j)的定义式(见“高斯混合模型之理解”),注意p(x|j)是j类高斯概率分布;

(3)·M-STEP:按照我们推导的公式,更新均值u,协方差sigma和模型概率pj;

(4)·将(3)中更新的参数带入(2)中更新Q(j);

(5)·最后要设定阀值,使迭代结束。按照定义,我们要将u,sigma,pj,带入L(theta)(最大似然值)公式中,如果t+1时刻的L与t时刻的L的比值接近于1,即可停止。具体的阀值还要应对实际的数据进行调整;


我的代码(MATLAB):

·初始化函数:

  1. function [ mu,m_sigma,mp ] = GMM_ini( data,n_center )  
  2.   
  3.   
  4. [m,n]=size(data);  
  5. [data_id,centers]=kmeans(data,n_center);  
  6. mu=centers;  
  7. mp=zeros(1,n_center);  
  8. m_sigma=zeros(n,n,n_center);  
  9.   
  10.   
  11. for i=1:n_center  
  12.     tem_id=(data_id==i);  
  13.     m_sigma(:,:,i)=sigma(data(tem_id,:));  
  14.     mp(i)=sum(tem_id)/m;  
  15. end  
  16.   
  17. end  
  1. function sig=sigma(data)//计算初始化的方差  
  2.   
  3.   
  4. [m,n]=size(data);  
  5. u=mean(data,1);  
  6. tem_data=data-repmat(u,m,1);  
  7.   
  8.   
  9. sig=zeros(n,n);  
  10. for k1=1:m  
  11. %     for k2=1:m   
  12.     sig=sig+tem_data(k1,:)'*tem_data(k1,:);  
  13. %     end  
  14. end  
  15. sig=(sig+ 1E-5.*diag(ones(n,1)))/m;  
  16. end  


·高斯概率分布函数

  1. function gp=GaussianPDF(data,u,sigma)  
  2.   
  3. [m,n]=size(data);  
  4.   
  5. pre_item=1/sqrt(((2*pi)^n)*abs(det(sigma)+realmin));  
  6. nxt_item(1:m)=0;  
  7. tem_data=data-repmat(u,m,1);  
  8. for i=1:m  
  9.    tem_data_t=tem_data(i,:)';  
  10.    nxt_item(i)=exp(-0.5*(tem_data(i,:)*(inv(sigma))*tem_data_t));   
  11. end  
  12.   
  13. gp=pre_item*nxt_item;  
  14.   
  15. end  

·EM算法函数

  1. function [mu,msigma,mp]=GMM(data,n_center,loglik_threshold)  
  2.   
  3. [ mu,msigma,mp ] = GMM_ini( data,n_center );  
  4. disp('GMM_Ini Completed ! ');  
  5.   
  6. Qt=E_step(data,mu,msigma,mp);  
  7. loglik_pre=loglike(data,mu,msigma,mp);  
  8. step=0;  
  9.   
  10. while 1  
  11.    [mu,msigma,mp]=M_step(Qt,data);  
  12.    loglik_nxt=loglike(data,mu,msigma,mp);  
  13.   if abs((loglik_nxt/loglik_pre)-1) < loglik_threshold    
  14.     break;    
  15.   end  
  16.     
  17.   if step>4  
  18.       break;  
  19.   end  
  20.     
  21.   step=step+1;  
  22.   step  
  23.   loglik_pre=loglik_nxt;  
  24.   Qt=E_step(data,mu,msigma,mp);  
  25.     
  26.     
  27. end  
  28.   
  29. end  
  30.   
  31. function Qt=E_step(data,mu,m_sigma,mp)//E_STEP  
  32.   
  33. n_model=length(mp);  
  34. m=size(data,1);  
  35. pxj(m,n_model)=0;  
  36.   
  37. for j=1:n_model  
  38.    pxj(:,j)=GaussianPDF(data,mu(j,:),m_sigma(:,:,j));  
  39. end  
  40.   
  41. px=pxj.*repmat(mp,m,1);  
  42. sp=sum(px,2);  
  43. Qt=px./repmat(sp,1,n_model);  
  44.   
  45. end  
  46.   
  47. function [lu,lsigma,lp]=M_step(Qt,data)//M_STEP  
  48.   
  49. [m,n_model]=size(Qt);  
  50. n=size(data,2);  
  51.   
  52. lu=zeros(n_model,n);  
  53. lsigma=zeros(n,n,n_model);  
  54. lp=zeros(1,n_model);  
  55.   
  56. mul_data=zeros(n,n);  
  57.   
  58. for j=1:n_model   
  59.     lu(j,:)=sum(data.*repmat(Qt(:,j),1,n))/sum(Qt(:,j));     
  60.     tem_data=data-repmat(lu(j,:),m,1);   
  61.         for k=1:m  
  62.              mul_data=mul_data+tem_data(k,:)'*tem_data(k,:)*Qt(k,j);   
  63.         end  
  64.     lsigma(:,:,j)=realmin+mul_data/sum(Qt(:,j));  
  65.     lp(j)=sum(Qt(:,j))/m;  
  66. end  
  67.   
  68.   
  69. end  
  70.   
  71. function loglik=loglike(data,mu,msigma,mp)//似然值  
  72.   
  73. n_center=size(mu,1);  
  74. pxj=zeros(size(data,1),n_center);  
  75.  for j=1:n_center    
  76.     pxj(:,j) = GaussianPDF(data, mu(j,:), msigma(:,:,j));    
  77.   end    
  78.   F = pxj*mp';    
  79.   F(F<realmin) = realmin;    
  80.   loglik = log(sum(F));   
  81.     
  82. end  

·测试函数

  1. clear all;  
  2. clc;  
  3. data=rand(1000,128);//1000个128维的数据样本  
  4. n_center=4;  
  5. thresh=0.0005;  
  6. [u,sigma,p]=GMM(data,n_center,thresh);  
  7.   
  8. disp('Test Completed !');  

注意,模型数在3-5左右,阀值要在0.0005-0.001,否则容易得到奇异方差矩阵。


    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多