分享

Weka开发[28]——EM源代码分析(2)

 lzqkean 2013-07-22

private double iterate(Instances inst, boolean report) throws Exception {

    int i;

    double llkold = 0.0;

    double llk = 0.0;

 

    boolean ok = false;

    int seed = m_rseed;

    int restartCount = 0;

    while (!ok) {

       try {

           for (i = 0; i < m_max_iterations; i ) {

              llkold = llk;

              llk = E(inst, true);

 

              if (i > 0) {

                  if ((llk - llkold) < 1e-6) {

                     break;

                  }

              }

              M(inst);

           }

           ok = true;

       } catch (Exception ex) {

       }

    }

 

    return llk;

}

       可以看到有两种迭代中止的方法:第一种是达到了m_max_iterations;第二种是llk-llkold小于阈值。llk是log likelihood的缩写,EM的目标就是最大化它,如果两次迭代之间的增量已经小于1e-6,说明已经接近最优值了,所以就停止了。下面就是E函数:

private double E(Instances inst, boolean change_weights) throws Exception {

 

    double loglk = 0.0, sOW = 0.0;

 

    for (int l = 0; l < inst.numInstances(); l ) {

 

       Instance in = inst.instance(l);

 

       loglk = in.weight() * logDensityForInstance(in);

       sOW = in.weight();

 

       if (change_weights) {

           m_weights[l] = distributionForInstance(in);

       }

    }

 

    // reestimate priors

    if (change_weights) {

       estimate_priors(inst);

    }

    return loglk / sOW;

}

       这里logDensityForInstance的代码:

public double logDensityForInstance(Instance instance) throws Exception {

    double[] a = logJointDensitiesForInstance(instance);

    double max = a[Utils.maxIndex(a)];

    double sum = 0.0;

 

    for (int i = 0; i < a.length; i ) {

       sum = Math.exp(a[i] - max);

    }

 

    return max Math.log(sum);

}

       logJointDensitiesForInstance

public double[] logJointDensitiesForInstance(Instance inst)

       throws Exception {

 

    double[] weights = logDensityPerClusterForInstance(inst);

    double[] priors = clusterPriors();

 

    for (int i = 0; i < weights.length; i ) {

       if (priors[i] > 0) {

           weights[i] = Math.log(priors[i]);

       } else {

           throw new IllegalArgumentException("Cluster empty!");

       }

    }

    return weights;

}

       logDensityPerClusterForInstance

public double[] logDensityPerClusterForInstance(Instance inst)

       throws Exception {

 

    int i, j;

    double logprob;

    double[] wghts = new double[m_num_clusters];

 

    m_replaceMissing.input(inst);

    inst = m_replaceMissing.output();

 

    for (i = 0; i < m_num_clusters; i ) {

       logprob = 0.0;

 

       for (j = 0; j < m_num_attribs; j ) {

           if (!inst.isMissing(j)) {

              if (inst.attribute(j).isNominal()) {

                  logprob = Math.log(m_model[i][j].

getProbability(inst.value(j)));

              } else { // numeric attribute

                  logprob = logNormalDens(inst.value(j),

                         m_modelNormal[i][j][0],

                         m_modelNormal[i][j][1]);

              }

           }

       }

 

       wghts[i] = logprob;

    }

    return wghts;

}

       对于离散型属性,这里计算它的概率的对数,它的概率计算很简单:

public double getProbability(double data) {

 

    if (m_SumOfCounts == 0) {

       return 0;

    }

    return (double) m_Counts[(int) data] / m_SumOfCounts;

}

       由于DiscreteEstimator在构造时开启了Laplace平滑(见后面new_estimators中传入的第二个参数true),这里得到的就是Laplace平滑后的概率。而对于连续属性:

private static double m_normConst = Math.log(Math.sqrt(2 * Math.PI));

 

private double logNormalDens(double x, double mean, double stdDev) {

 

    double diff = x - mean;

 

    return -(diff * diff / (2 * stdDev * stdDev)) - m_normConst

           - Math.log(stdDev);

}

       diff就是高斯分布(正态分布)中的x-mean,而下面的那一长串就是对高斯分布的密度公式取对数得到的:$\log \mathcal{N}(x;\mu,\sigma) = -\frac{(x-\mu)^2}{2\sigma^2} - \log\sqrt{2\pi} - \log\sigma$。

       回到logJointDensitiesForInstance中,clusterPriors的代码如下:

public double[] clusterPriors() {

 

    double[] n = new double[m_priors.length];

 

    System.arraycopy(m_priors, 0, n, 0, n.length);

    return n;

}

       只是复制一下。而在clusterPriors之后,weights[i] += Math.log(priors[i])是把先验概率的对数加到weights[i]上,可以展开来看:weights[i]原来是$\log P(x_i|z_i)$,priors[i]是$P(z_i)$,所以结果就是联合概率的对数:$\log\big(P(x_i|z_i)P(z_i)\big)=\log P(x_i|z_i)+\log P(z_i)$。

       再回到logDensityForInstance中,这里的Math.exp(a[i]-max)看起来奇怪,公式里没有a[i]-max。这其实是log-sum-exp技巧:$\log\sum_i e^{a_i} = \max + \log\sum_i e^{a_i-\max}$,先减去最大值再取指数可以避免数值下溢,最后再把max加回来,结果不变。可以参考一下Andrew Ng的lecture notes中supervised learning部分的第28页。求和得到的其实也就是后验概率公式中的分母。

       E函数中,有distributionForInstance这个函数如下:

public double[] distributionForInstance(Instance instance) throws Exception {

 

    return Utils.logs2probs(logJointDensitiesForInstance(instance));

}

       函数logs2probs的注释中写道:将概率取自然对数后的数组转换回概率,并假定这些概率的和为1(Converts an array containing the natural logarithms of probabilities stored in a vector back into probabilities. The probabilities are assumed to sum to one.)。下面是estimate_priors:

private void estimate_priors(Instances inst) throws Exception {

 

    for (int i = 0; i < m_num_clusters; i ) {

       m_priors[i] = 0.0;

    }

 

    for (int i = 0; i < inst.numInstances(); i ) {

       for (int j = 0; j < m_num_clusters; j ) {

           m_priors[j] = inst.instance(i).weight() * m_weights[i][j];

       }

    }

 

    Utils.normalize(m_priors);

}

       计算每个类的先验概率,就是将每个样本的权重×该样本属于某个类的概率累加起来,最后再归一化。

       下面是M函数的内容:

       new_estimators的代码如下:

private void new_estimators() {

    for (int i = 0; i < m_num_clusters; i ) {

       for (int j = 0; j < m_num_attribs; j ) {

           if (m_theInstances.attribute(j).isNominal()) {

              m_model[i][j] = new DiscreteEstimator(m_theInstances

                     .attribute(j).numValues(), true);

           } else {

              m_modelNormal[i][j][0] = m_modelNormal[i][j][1] =

                  m_modelNormal[i][j][2] = 0.0;

           }

       }

    }

}

       重新初始化m_model和m_modelNormal。接下来是M函数中累加统计量的部分:

// Accumulate, per cluster and attribute, the statistics of every instance
// whose value for that attribute is not missing. Each instance contributes
// with weight in.weight() * m_weights[l][i].
for (i = 0; i < m_num_clusters; i++) {
    for (j = 0; j < m_num_attribs; j++) {
        for (l = 0; l < inst.numInstances(); l++) {
            Instance in = inst.instance(l);
            if (!in.isMissing(j)) {
                if (inst.attribute(j).isNominal()) {
                    m_model[i][j].addValue(in.value(j), in.weight()
                            * m_weights[l][i]);
                } else {
                    // [0]: weighted sum of values
                    m_modelNormal[i][j][0] += (in.value(j)
                            * in.weight() * m_weights[l][i]);
                    // [2]: sum of weights
                    m_modelNormal[i][j][2] += in.weight()
                            * m_weights[l][i];
                    // [1]: weighted sum of squared values
                    m_modelNormal[i][j][1] += (in.value(j)
                            * in.value(j) * in.weight() * m_weights[l][i]);
                }
            }
        }
    }
}

       这里的代码与初始化的时候是不一样的(当然不一样,一样不就没进步了吗?),如果是离散值,那么比较简单,就是将属于这个簇的权重×样本权重加进去就可以了;如果是连续值,也不复杂,把加权和、权重和以及加权平方和记录一下就可以了,当然这还只是累计。

// calculate mean and std deviation for numeric attributes
// On entry: [0] = weighted sum, [1] = weighted sum of squares, [2] = sum of
// weights. On exit: [0] = mean, [1] = standard deviation.
for (j = 0; j < m_num_attribs; j++) {
    if (!inst.attribute(j).isNominal()) {
        for (i = 0; i < m_num_clusters; i++) {
            if (m_modelNormal[i][j][2] <= 0) {
                // No weight was accumulated for this cluster/attribute.
                m_modelNormal[i][j][1] = Double.MAX_VALUE;
                //      m_modelNormal[i][j][0] = 0;
                m_modelNormal[i][j][0] = m_minStdDev;
            } else {

                // variance
                m_modelNormal[i][j][1] = (m_modelNormal[i][j][1]
                        - (m_modelNormal[i][j][0] * m_modelNormal[i][j][0]
                                / m_modelNormal[i][j][2]))
                        / (m_modelNormal[i][j][2]);

                // Clamp negative results to zero before taking the root.
                if (m_modelNormal[i][j][1] < 0) {
                    m_modelNormal[i][j][1] = 0;
                }

                // std dev
                double minStdD = (m_minStdDevPerAtt != null)
                        ? m_minStdDevPerAtt[j]
                        : m_minStdDev;

                m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);

                // Too small: fall back to the attribute's overall std dev,
                // and to the minimum if that is still too small.
                if ((m_modelNormal[i][j][1] <= minStdD)) {
                    m_modelNormal[i][j][1] =
                            inst.attributeStats(j).numericStats.stdDev;
                    if ((m_modelNormal[i][j][1] <= minStdD)) {
                        m_modelNormal[i][j][1] = minStdD;
                    }
                }
                if ((m_modelNormal[i][j][1] <= 0)) {
                    m_modelNormal[i][j][1] = m_minStdDev;
                }
                if (Double.isInfinite(m_modelNormal[i][j][1])) {
                    m_modelNormal[i][j][1] = m_minStdDev;
                }

                // mean
                m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
            }
        }
    }
}

       这里还要继续对连续属性进行处理,因为上次只是进行了累加。如果m_modelNormal[i][j][2] <= 0,表示没有样本属于这个类;else的代码很长,其实很简单:第一步是计算方差,然后得到有关精度的minStdD,开方后得到标准差,再根据精度对标准差进行处理,最后是计算均值。

// Best cross-validated log likelihood so far, initialised to the most
// negative finite double.
double CVLogLikely = -Double.MAX_VALUE;
double templl;
double tll;
boolean CVincreased = true;

// Start with a single cluster.
m_num_clusters = 1;
int num_clusters = m_num_clusters;

int i;
Random cvr;
Instances trainCopy;

// At most 10 folds; fewer when there are fewer than 10 instances.
int numFolds = Math.min(m_theInstances.numInstances(), 10);

boolean ok = true;
int seed = m_rseed;
int restartCount = 0;

       只是初始化一点内容。

// Keep adding clusters while the mean cross-validated log likelihood over
// the folds improves.
// NOTE(review): after a failed fold, ok stays false and CVincreased stays
// false within this fragment, so the outer loop ends instead of retrying
// with the new seed - confirm this is intended.
CLUSTER_SEARCH: while (CVincreased) {
    // theInstances.stratify(10);

    CVincreased = false;
    cvr = new Random(m_rseed);
    trainCopy = new Instances(m_theInstances);
    trainCopy.randomize(cvr);
    templl = 0.0;
    for (i = 0; i < numFolds; i++) {
        Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
        // More clusters than training instances: stop the search.
        if (num_clusters > cvTrain.numInstances()) {
            break CLUSTER_SEARCH;
        }
        Instances cvTest = trainCopy.testCV(numFolds, i);
        m_rr = new Random(seed);
        for (int z = 0; z < 10; z++) {
            m_rr.nextDouble();
        }
        m_num_clusters = num_clusters;
        EM_Init(cvTrain);
        try {
            iterate(cvTrain, false);
        } catch (Exception ex) {
            // catch any problems - i.e. empty clusters occurring
            ex.printStackTrace();
            seed++;
            restartCount++;
            ok = false;
            if (restartCount > 5) {
                break CLUSTER_SEARCH;
            }
            break;
        }
        try {
            tll = E(cvTest, false);
        } catch (Exception ex) {
            ex.printStackTrace();
            seed++;
            restartCount++;
            ok = false;
            if (restartCount > 5) {
                break CLUSTER_SEARCH;
            }
            break;
        }

        // Sum the test log likelihood over the folds.
        templl += tll;
    }

    if (ok) {
        restartCount = 0;
        seed = m_rseed;
        templl /= (double) numFolds;

        if (templl > CVLogLikely) {
            CVLogLikely = templl;
            CVincreased = true;
            num_clusters++;
        }
    }
}

       trainCopy是待聚类的数据,如果没有什么意外,那么numFolds等于10,trainCV和testCV相当于是cross validation中第i次的训练和测试数据。如果要聚的类比样本还多,当然是break。再用EM_Init初始化,iterate进行迭代,然后用E在cvTest上测试得到log likelihood。如果没有遇到异常,那么求得templl(各折的平均值),看当前的num_clusters是不是比以前更小的num_clusters取得了更好的结果:如果是,就把num_clusters加1继续尝试;如果没有,那么整个循环就结束了。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多