分享

混合高斯模型(GMM)Spark MLlib调用实例(Scala/Java/Python)

 陈永正的图书馆 2017-06-08

高斯混合模型

算法原理:

    混合高斯模型描述数据点以一定的概率服从k种高斯子分布的一种混合分布。Spark.ml使用EM算法给出一组样本的极大似然模型。

参数:

featuresCol:

类型:字符串型。

含义:特征列名。

k:

类型:整数型。

含义:混合模型中独立的高斯数目。

maxIter:

类型:整数型。

含义:迭代次数(>=0)。

predictionCol:

类型:字符串型。

含义:预测结果列名。

probabilityCol:

类型:字符串型。

含义:用以预测类别条件概率的列名。

seed:

类型:长整型。

含义:随机种子。

tol:

类型:双精度型。

含义:迭代算法的收敛性。

调用示例:

Scala:

[plain] view plain copy
  1. import org.apache.spark.ml.clustering.GaussianMixture  
  2.   
  3. // Loads data  
  4. val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")  
  5.   
  6. // Trains Gaussian Mixture Model  
  7. val gmm = new GaussianMixture()  
  8.   .setK(2)  
  9. val model = gmm.fit(dataset)  
  10.   
  11. // output parameters of mixture model model  
  12. for (i <- 0 until model.getK) {  
  13.   println("weight=%f\nmu=%s\nsigma=\n%s\n" format  
  14.     (model.weights(i), model.gaussians(i).mean, model.gaussians(i).cov))  
  15. }  
Java:

  1. import org.apache.spark.ml.clustering.GaussianMixture;  
  2. import org.apache.spark.ml.clustering.GaussianMixtureModel;  
  3. import org.apache.spark.sql.Dataset;  
  4. import org.apache.spark.sql.Row;  
  5.   
  6. // Loads data  
  7. Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");  
  8.   
  9. // Trains a GaussianMixture model  
  10. GaussianMixture gmm = new GaussianMixture()  
  11.   .setK(2);  
  12. GaussianMixtureModel model = gmm.fit(dataset);  
  13.   
  14. // Output the parameters of the mixture model  
  15. for (int i = 0; i < model.getK(); i++) {  
  16.   System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",  
  17.           model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov());  
  18. }  
Python:

[python] view plain copy
  1. from pyspark.ml.clustering import GaussianMixture  
  2.   
  3. # loads data  
  4. dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")  
  5.   
  6. gmm = GaussianMixture().setK(2)  
  7. model = gmm.fit(dataset)  
  8.   
  9. print("Gaussians: ")  
  10. model.gaussiansDF.show()  


    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多