分享

算法——K均值聚类算法扩展应用(Java实现)

 学海无涯GL 2013-11-26

1、前面一篇文章算法——K均值聚类算法(Java实现)简单的实现了一下K均值分类算法,这节我们对于他的应用进行一个扩展应用

2、目标为对对象的分类

3、具体实现如下

1)首先建立一个基类KmeansObject,目的为继承该类的子类都可以应用我们的k均值算法进行分类,代码如下

  1. package org.cyxl.util.algorithm;  
  2.   
  3. /** 
  4.  * 所有使用k均值分类算法的对象都必须继承自该对象 
  5.  * @author cyxl 
  6.  * @version 1.0 2012-05-24 
  7.  * @since 1.0 
  8.  * 
  9.  */  
  10. public class KmeansObject {  
  11.     public float compare;       //比较因子   
  12. }  

2)算法实现,代码如下

  1. package org.cyxl.util.algorithm;  
  2.   
  3. import java.util.ArrayList;  
  4. import java.util.Random;  
  5.   
  6. /** 
  7.  * K均值聚类算法 
  8.  */  
  9. public class CommonKmeans {  
  10.     private int k;// 分成多少簇   
  11.     private int m;// 迭代次数   
  12.     private int dataSetLength;// 数据集元素个数,即数据集的长度   
  13.     private ArrayList<KmeansObject> dataSet;// 数据集链表   
  14.     private ArrayList<KmeansObject> center;// 中心链表   
  15.     private ArrayList<ArrayList<KmeansObject>> cluster; // 簇   
  16.     private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小   
  17.     private Random random;  
  18.   
  19.     /** 
  20.      * 设置需分组的原始数据集 
  21.      *  
  22.      * @param dataSet 
  23.      */  
  24.   
  25.     public void setDataSet(ArrayList<KmeansObject> dataSet) {  
  26.         this.dataSet = dataSet;  
  27.     }  
  28.   
  29.     /** 
  30.      * 获取结果分组 
  31.      *  
  32.      * @return 结果集 
  33.      */  
  34.   
  35.     public ArrayList<ArrayList<KmeansObject>> getCluster() {  
  36.         return cluster;  
  37.     }  
  38.   
  39.     /** 
  40.      * 构造函数,传入需要分成的簇数量 
  41.      *  
  42.      * @param k 
  43.      *            簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度 
  44.      */  
  45.     public CommonKmeans(int k) {  
  46.         if (k <= 0) {  
  47.             k = 1;  
  48.         }  
  49.         this.k = k;  
  50.     }  
  51.   
  52.     /** 
  53.      * 初始化 
  54.      */  
  55.     private void init() {  
  56.         m = 0;  
  57.         random = new Random();  
  58.         if (dataSet == null || dataSet.size() == 0) {  
  59.             initDataSet();  
  60.         }  
  61.         dataSetLength = dataSet.size();  
  62.         if (k > dataSetLength) {  
  63.             k = dataSetLength;  
  64.         }  
  65.         center = initCenters();  
  66.         cluster = initCluster();  
  67.         jc = new ArrayList<Float>();  
  68.     }  
  69.   
  70.     /** 
  71.      * 如果调用者未初始化数据集,则采用内部测试数据集 
  72.      */  
  73.     private void initDataSet() {  
  74.         dataSet = new ArrayList<KmeansObject>();  
  75.           
  76.         for(int i=0;i<10;i++)  
  77.         {  
  78.             int temp = random.nextInt(100);  
  79.             KmeansObject ko=new KmeansObject();  
  80.             ko.compare=temp;  
  81.             dataSet.add(ko);  
  82.         }  
  83.     }  
  84.   
  85.     /** 
  86.      * 初始化中心数据链表,分成多少簇就有多少个中心点 
  87.      *  
  88.      * @return 中心点集 
  89.      */  
  90.     private ArrayList<KmeansObject> initCenters() {  
  91.         ArrayList<KmeansObject> center = new ArrayList<KmeansObject>();  
  92.         int[] randoms = new int[k];  
  93.         boolean flag;  
  94.         int temp = random.nextInt(dataSetLength);  
  95.         randoms[0] = temp;  
  96.         for (int i = 1; i < k; i++) {  
  97.             flag = true;  
  98.             while (flag) {  
  99.                 temp = random.nextInt(dataSetLength);  
  100.                 int j = 0;  
  101.                 // 不清楚for循环导致j无法加1   
  102.                 // for(j=0;j<i;++j)   
  103.                 // {   
  104.                 // if(temp==randoms[j]);   
  105.                 // {   
  106.                 // break;   
  107.                 // }   
  108.                 // }   
  109.                 while (j < i) {  
  110.                     if (temp == randoms[j]) {  
  111.                         break;  
  112.                     }  
  113.                     j++;  
  114.                 }  
  115.                 if (j == i) {  
  116.                     flag = false;  
  117.                 }  
  118.             }  
  119.             randoms[i] = temp;  
  120.         }  
  121.   
  122.         for (int i = 0; i < k; i++) {  
  123.             center.add(dataSet.get(randoms[i]));// 生成初始化中心链表   
  124.         }  
  125.         return center;  
  126.     }  
  127.   
  128.     /** 
  129.      * 初始化簇集合 
  130.      *  
  131.      * @return 一个分为k簇的空数据的簇集合 
  132.      */  
  133.     private ArrayList<ArrayList<KmeansObject>> initCluster() {  
  134.         ArrayList<ArrayList<KmeansObject>> cluster = new ArrayList<ArrayList<KmeansObject>>();  
  135.         for (int i = 0; i < k; i++) {  
  136.             cluster.add(new ArrayList<KmeansObject>());  
  137.         }  
  138.   
  139.         return cluster;  
  140.     }  
  141.   
  142.     /** 
  143.      * 计算两个点之间的距离 
  144.      *  
  145.      * @param element 
  146.      *            点1 
  147.      * @param center 
  148.      *            点2 
  149.      * @return 距离 
  150.      */  
  151.     private float distance(KmeansObject element, KmeansObject center) {  
  152.         float distance = 0.0f;  
  153.   
  154.         distance=Math.abs(element.compare-center.compare);  
  155.           
  156.         return distance;  
  157.     }  
  158.   
  159.     /** 
  160.      * 获取距离集合中最小距离的位置 
  161.      *  
  162.      * @param distance 
  163.      *            距离数组 
  164.      * @return 最小距离在距离数组中的位置 
  165.      */  
  166.     private int minDistance(float[] distance) {  
  167.         float minDistance = distance[0];  
  168.         int minLocation = 0;  
  169.         for (int i = 1; i < distance.length; i++) {  
  170.             if (distance[i] < minDistance) {  
  171.                 minDistance = distance[i];  
  172.                 minLocation = i;  
  173.             } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置   
  174.             {  
  175.                 if (random.nextInt(10) < 5) {  
  176.                     minLocation = i;  
  177.                 }  
  178.             }  
  179.         }  
  180.   
  181.         return minLocation;  
  182.     }  
  183.   
  184.     /** 
  185.      * 核心,将当前元素放到最小距离中心相关的簇中 
  186.      */  
  187.     private void clusterSet() {  
  188.         float[] distance = new float[k];  
  189.         for (int i = 0; i < dataSetLength; i++) {  
  190.             for (int j = 0; j < k; j++) {  
  191.                 distance[j] = distance(dataSet.get(i), center.get(j));  
  192.   
  193.             }  
  194.             int minLocation = minDistance(distance);  
  195.   
  196.             cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中   
  197.   
  198.         }  
  199.     }  
  200.   
  201.     /** 
  202.      * 求两点误差平方的方法 
  203.      *  
  204.      * @param element 
  205.      *            点1 
  206.      * @param center 
  207.      *            点2 
  208.      * @return 误差平方 
  209.      */  
  210.     private float errorSquare(KmeansObject element, KmeansObject center) {  
  211.         float x = Math.abs(element.compare-center.compare);  
  212.           
  213.         float errSquare = x * x;  
  214.   
  215.         return errSquare;  
  216.     }  
  217.   
  218.     /** 
  219.      * 计算误差平方和准则函数方法 
  220.      */  
  221.     private void countRule() {  
  222.         float jcF = 0;  
  223.         for (int i = 0; i < cluster.size(); i++) {  
  224.             for (int j = 0; j < cluster.get(i).size(); j++) {  
  225.                 jcF += errorSquare(cluster.get(i).get(j), center.get(i));  
  226.   
  227.             }  
  228.         }  
  229.         jc.add(jcF);  
  230.     }  
  231.   
  232.     /** 
  233.      * 设置新的簇中心方法 
  234.      */  
  235.     private void setNewCenter() {  
  236.         for (int i = 0; i < k; i++) {  
  237.             int n = cluster.get(i).size();  
  238.             if (n != 0) {  
  239.                 KmeansObject newCenter = new KmeansObject();  
  240.                 for (int j = 0; j < n; j++) {  
  241.                     newCenter.compare += cluster.get(i).get(j).compare;  
  242.                 }  
  243.                 // 设置一个平均值   
  244.                 newCenter.compare=newCenter.compare/n;  
  245.                   
  246.                 center.set(i, newCenter);  
  247.             }  
  248.         }  
  249.     }  
  250.   
  251.     /** 
  252.      * 打印数据,测试用 
  253.      *  
  254.      * @param dataArray 
  255.      *            数据集 
  256.      * @param dataArrayName 
  257.      *            数据集名称 
  258.      */  
  259.     public void printDataArray(ArrayList<KmeansObject> dataArray,  
  260.             String dataArrayName) {  
  261.         for (int i = 0; i < dataArray.size(); i++) {  
  262.             System.out.println("print:" + dataArrayName + "[" + i + "]={"  
  263.                     + dataArray.get(i) + "}");  
  264.         }  
  265.         System.out.println("===================================");  
  266.     }  
  267.   
  268.     /** 
  269.      * Kmeans算法核心过程方法 
  270.      */  
  271.     private void kmeans() {  
  272.         init();  
  273.   
  274.         // 循环分组,直到误差不变为止   
  275.         while (true) {  
  276.             clusterSet();  
  277.   
  278.             countRule();  
  279.               
  280.             // 误差不变了,分组完成   
  281.             if (m != 0) {  
  282.                 if (jc.get(m) - jc.get(m - 1) == 0) {  
  283.                     break;  
  284.                 }  
  285.             }  
  286.   
  287.             setNewCenter();  
  288.             m++;  
  289.             cluster.clear();  
  290.             cluster = initCluster();  
  291.         }  
  292.           
  293.     }  
  294.   
  295.     /** 
  296.      * 执行算法 
  297.      */  
  298.     public void execute() {  
  299.         long startTime = System.currentTimeMillis();  
  300.         System.out.println("kmeans begins");  
  301.         kmeans();  
  302.         long endTime = System.currentTimeMillis();  
  303.         System.out.println("kmeans running time=" + (endTime - startTime)  
  304.                 + "ms");  
  305.         System.out.println("kmeans ends");  
  306.         System.out.println();  
  307.     }  
  308.       
  309.       
  310. }  

3)测试算法,首先建立一个Person类,目标在于对人进行分类

  1. package org.cyxl.util.algorithm;  
  2.   
  3. public class Person extends KmeansObject {  
  4.     String name="";  
  5.     int age=0;  
  6.     float qz=1;     //权重   
  7.       
  8.     public Person(){}  
  9.       
  10.     public Person(String name,int age,float qz)  
  11.     {  
  12.         this.name=name;  
  13.         this.age=age;  
  14.         this.qz=qz;  
  15.     }  
  16.       
  17.     public String getName() {  
  18.         return name;  
  19.     }  
  20.     public void setName(String name) {  
  21.         this.name = name;  
  22.     }  
  23.     public int getAge() {  
  24.         return age;  
  25.     }  
  26.     public void setAge(int age) {  
  27.         this.age = age;  
  28.     }  
  29.       
  30.     public float getQz() {  
  31.         return qz;  
  32.     }  
  33.   
  34.     public void setQz(float qz) {  
  35.         this.qz = qz;  
  36.     }  
  37.   
  38.     public String toString()  
  39.     {  
  40.         return "name:"+this.name+";age:"+this.age+";qz:"+this.qz+";compare:"+super.compare;  
  41.     }  
  42. }  

4)客户端测试代码

  1.               CommonKmeans k=new CommonKmeans(5);  
  2. ArrayList<KmeansObject> list=new ArrayList<KmeansObject>();  
  3.   
  4. for(int i=0;i<10;i++)  
  5. {  
  6.     float qz=(float)(new Random().nextInt(10))/10;  
  7.     Person p=new Person("name"+i,i,qz);  
  8.     p.compare=new Random().nextInt(100)*p.getQz();  
  9.     list.add(p);  
  10. }  
  11. k.setDataSet(list);  
  12. k.printDataArray(k.dataSet, "before");  
  13. k.execute();  
  14. ArrayList<ArrayList<KmeansObject>> cluster=k.getCluster();  
  15. //查看结果   
  16. for(int i=0;i<cluster.size();i++)  
  17. {  
  18.     k.printDataArray(cluster.get(i), "cluster["+i+"]");  
  19. }  

5)输出结果

  1. print:before[0]={name:name0;age:0;qz:0.0;compare:0.0}  
  2. print:before[1]={name:name1;age:1;qz:0.9;compare:48.6}  
  3. print:before[2]={name:name2;age:2;qz:0.9;compare:57.6}  
  4. print:before[3]={name:name3;age:3;qz:0.4;compare:28.4}  
  5. print:before[4]={name:name4;age:4;qz:0.0;compare:0.0}  
  6. print:before[5]={name:name5;age:5;qz:0.4;compare:33.600002}  
  7. print:before[6]={name:name6;age:6;qz:0.5;compare:2.0}  
  8. print:before[7]={name:name7;age:7;qz:0.2;compare:14.6}  
  9. print:before[8]={name:name8;age:8;qz:0.6;compare:5.4}  
  10. print:before[9]={name:name9;age:9;qz:0.9;compare:52.199997}  
  11. ===================================  
  12. kmeans begins  
  13. kmeans running time=0ms  
  14. kmeans ends  
  15.   
  16. print:cluster[0][0]={name:name3;age:3;qz:0.4;compare:28.4}  
  17. print:cluster[0][1]={name:name5;age:5;qz:0.4;compare:33.600002}  
  18. ===================================  
  19. print:cluster[1][0]={name:name7;age:7;qz:0.2;compare:14.6}  
  20. ===================================  
  21. print:cluster[2][0]={name:name2;age:2;qz:0.9;compare:57.6}  
  22. ===================================  
  23. print:cluster[3][0]={name:name1;age:1;qz:0.9;compare:48.6}  
  24. print:cluster[3][1]={name:name9;age:9;qz:0.9;compare:52.199997}  
  25. ===================================  
  26. print:cluster[4][0]={name:name0;age:0;qz:0.0;compare:0.0}  
  27. print:cluster[4][1]={name:name4;age:4;qz:0.0;compare:0.0}  
  28. print:cluster[4][2]={name:name6;age:6;qz:0.5;compare:2.0}  
  29. print:cluster[4][3]={name:name8;age:8;qz:0.6;compare:5.4}  
  30. ===================================  

4、说明及总结。

       1)基类KmeansObject定义了一个compare,我们把它叫做比较因子,分类时只要就是对分类因子进行分类计算的。所以这个分类因子很重要,每个对象的分类因子可以具体的根据业务进行计算设置。比如我们客户端测试代码中的比较因子的计算方法是,首先给每个对象赋予一个权值qz,然后根据权值和年龄的乘积(具体计算方法根据业务定)来对人群进行分类

2)该算法中对于比较因子compare的计算是影响该算法准确性的一个很重要方面,具体表现在距离(distance方法)和误差(errorSquare方法)计算中。想要改善该算法可以从这两个方法中进行修改

3)当然,我对于这个算法的实现和应用都还是很浅。如果有什么不对或者可以改善的地方请不吝赐教

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多