分享

java实现一元线性回归算法

 昵称50o19 2017-01-02

在网上看到一位高手用 Java 写的一元线性回归实现，觉得挺有用。不少企业做数据挖掘时都会用到"预测运营收入"这类功能，而一元线性回归算法正好可以实现类似的预测。下面直接上代码：

1、定义一个DataPoint类,对X和Y坐标点进行封装:

/** * File : DataPoint.java * Author : zhouyujie * Date : 2012-01-11 16:00:00 * Description : Java实现一元线性回归的算法,座标点实体类,(可实现统计指标的预测) */ package com.zyujie.dm; public class DataPoint { /** the x value */ public float x; /** the y value */ public float y; /** * Constructor. * * @param x * the x value * @param y * the y value */ public DataPoint(float x, float y) { this.x = x; this.y = y; } }
2、下面是回归线的算法实现（RegressionLine 类）：

/** * File : DataPoint.java * Author : zhouyujie * Date : 2012-01-11 16:00:00 * Description : Java实现一元线性回归的算法,回归线实现类,(可实现统计指标的预测) */ package com.zyujie.dm; import java.math.BigDecimal; import java.util.ArrayList; public class RegressionLine // implements Evaluatable { /** sum of x */ private double sumX; /** sum of y */ private double sumY; /** sum of x*x */ private double sumXX; /** sum of x*y */ private double sumXY; /** sum of y*y */ private double sumYY; /** sum of yi-y */ private double sumDeltaY; /** sum of sumDeltaY^2 */ private double sumDeltaY2; /** 误差 */ private double sse; private double sst; private double E; private String[] xy; private ArrayList listX; private ArrayList listY; private int XMin, XMax, YMin, YMax; /** line coefficient a0 */ private float a0; /** line coefficient a1 */ private float a1; /** number of data points */ private int pn; /** true if coefficients valid */ private boolean coefsValid; /** * Constructor. */ public RegressionLine() { XMax = 0; YMax = 0; pn = 0; xy = new String[2]; listX = new ArrayList(); listY = new ArrayList(); } /** * Constructor. * * @param data * the array of data points */ public RegressionLine(DataPoint data[]) { pn = 0; xy = new String[2]; listX = new ArrayList(); listY = new ArrayList(); for (int i = 0; i < data.length; i) { addDataPoint(data[i]); } } /** * Return the current number of data points. * * @return the count */ public int getDataPointCount() { return pn; } /** * Return the coefficient a0. * * @return the value of a0 */ public float getA0() { validateCoefficients(); return a0; } /** * Return the coefficient a1. * * @return the value of a1 */ public float getA1() { validateCoefficients(); return a1; } /** * Return the sum of the x values. * * @return the sum */ public double getSumX() { return sumX; } /** * Return the sum of the y values. * * @return the sum */ public double getSumY() { return sumY; } /** * Return the sum of the x*x values. 
* * @return the sum */ public double getSumXX() { return sumXX; } /** * Return the sum of the x*y values. * * @return the sum */ public double getSumXY() { return sumXY; } public double getSumYY() { return sumYY; } public int getXMin() { return XMin; } public int getXMax() { return XMax; } public int getYMin() { return YMin; } public int getYMax() { return YMax; } /** * Add a new data point: Update the sums. * * @param dataPoint * the new data point */ public void addDataPoint(DataPoint dataPoint) { sumX = dataPoint.x; sumY = dataPoint.y; sumXX = dataPoint.x * dataPoint.x; sumXY = dataPoint.x * dataPoint.y; sumYY = dataPoint.y * dataPoint.y; if (dataPoint.x > XMax) { XMax = (int) dataPoint.x; } if (dataPoint.y > YMax) { YMax = (int) dataPoint.y; } // 把每个点的具体坐标存入ArrayList中,备用 xy[0] = (int) dataPoint.x ''; xy[1] = (int) dataPoint.y ''; if (dataPoint.x != 0 && dataPoint.y != 0) { System.out.print(xy[0] ','); System.out.println(xy[1]); try { // System.out.println('n:' n); listX.add(pn, xy[0]); listY.add(pn, xy[1]); } catch (Exception e) { e.printStackTrace(); } /* * System.out.println('N:' n); System.out.println('ArrayList * listX:' listX.get(n)); System.out.println('ArrayList listY:' * listY.get(n)); */ } pn; coefsValid = false; } /** * Return the value of the regression line function at x. (Implementation of * Evaluatable.) * * @param x * the value of x * @return the value of the function at x */ public float at(int x) { if (pn < 2) return Float.NaN; validateCoefficients(); return a0 a1 * x; } /** * Reset. */ public void reset() { pn = 0; sumX = sumY = sumXX = sumXY = 0; coefsValid = false; } /** * Validate the coefficients. 
计算方程系数 y=ax b 中的a */ private void validateCoefficients() { if (coefsValid) return; if (pn >= 2) { float xBar = (float) sumX / pn; float yBar = (float) sumY / pn; a1 = (float) ((pn * sumXY - sumX * sumY) / (pn * sumXX - sumX * sumX)); a0 = (float) (yBar - a1 * xBar); } else { a0 = a1 = Float.NaN; } coefsValid = true; } /** * 返回误差 */ public double getR() { // 遍历这个list并计算分母 for (int i = 0; i < pn - 1; i ) { float Yi = (float) Integer.parseInt(listY.get(i).toString()); float Y = at(Integer.parseInt(listX.get(i).toString())); float deltaY = Yi - Y; float deltaY2 = deltaY * deltaY; /* * System.out.println('Yi:' Yi); System.out.println('Y:' Y); * System.out.println('deltaY:' deltaY); * System.out.println('deltaY2:' deltaY2); */ sumDeltaY2 = deltaY2; // System.out.println('sumDeltaY2:' sumDeltaY2); } sst = sumYY - (sumY * sumY) / pn; // System.out.println('sst:' sst); E = 1 - sumDeltaY2 / sst; return round(E, 4); } // 用于实现精确的四舍五入 public double round(double v, int scale) { if (scale < 0) { throw new IllegalArgumentException( 'The scale must be a positive integer or zero'); } BigDecimal b = new BigDecimal(Double.toString(v)); BigDecimal one = new BigDecimal('1'); return b.divide(one, scale, BigDecimal.ROUND_HALF_UP).doubleValue(); } public float round(float v, int scale) { if (scale < 0) { throw new IllegalArgumentException( 'The scale must be a positive integer or zero'); } BigDecimal b = new BigDecimal(Double.toString(v)); BigDecimal one = new BigDecimal('1'); return b.divide(one, scale, BigDecimal.ROUND_HALF_UP).floatValue(); } }
3、线性回归测试类:

/** * File : DataPoint.java * Author : zhouyujie * Date : 2012-01-11 16:00:00 * Description : Java实现一元线性回归的算法,线性回归测试类,(可实现统计指标的预测) */ package com.zyujie.dm; /** * <p> * <b>Linear Regression</b> <br> * Demonstrate linear regression by constructing the regression line for a set * of data points. * * <p> * require DataPoint.java,RegressionLine.java * * <p> * 为了计算对于给定数据点的最小方差回线,需要计算SumX,SumY,SumXX,SumXY; (注:SumXX = Sum (X^2)) * <p> * <b>回归直线方程如下: f(x)=a1x a0 </b> * <p> * <b>斜率和截距的计算公式如下:</b> <br> * n: 数据点个数 * <p> * a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2) <br> * a0=(SumY - SumY * a1)/n <br> * (也可表达为a0=averageY-a1*averageX) * * <p> * <b>画线的原理:两点成一直线,只要能确定两个点即可</b><br> * 第一点:(0,a0) 再随意取一个x1值代入方程,取得y1,连结(0,a0)和(x1,y1)两点即可。 * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax,即两点为(0,a0),(Xmax,Y)。如果y=a1*Xmax a0,y大于 * 纵坐标最大值Ymax,则不用这个点。改用y取最大值Ymax,算得此时x的值,使用(X,Ymax), 即两点为(0,a0),(X,Ymax) * * <p> * <b>拟合度计算:(即Excel中的R^2)</b> * <p> * *R2 = 1 - E * <p> * 误差E的计算:E = SSE/SST * <p> * SSE=sum((Yi-Y)^2) SST=sumYY - (sumY*sumY)/n; * <p> */ public class LinearRegression { private static final int MAX_POINTS = 10; private double E; /** * Main program. * * @param args * the array of runtime arguments */ public static void main(String args[]) { RegressionLine line = new RegressionLine(); line.addDataPoint(new DataPoint(1, 136)); line.addDataPoint(new DataPoint(2, 143)); line.addDataPoint(new DataPoint(3, 132)); line.addDataPoint(new DataPoint(4, 142)); line.addDataPoint(new DataPoint(5, 147)); printSums(line); printLine(line); } /** * Print the computed sums. * * @param line * the regression line */ private static void printSums(RegressionLine line) { System.out.println('\n数据点个数 n = ' line.getDataPointCount()); System.out.println('\nSum x = ' line.getSumX()); System.out.println('Sum y = ' line.getSumY()); System.out.println('Sum xx = ' line.getSumXX()); System.out.println('Sum xy = ' line.getSumXY()); System.out.println('Sum yy = ' line.getSumYY()); } /** * Print the regression line function. 
* * @param line * the regression line */ private static void printLine(RegressionLine line) { System.out.println('\n回归线公式: y = ' line.getA1() 'x ' line.getA0()); System.out.println('误差: R^2 = ' line.getR()); } //y = 2.1x 133.7 2.1 * 6 133.7 = 12.6 133.7 = 146.3 //y = 2.1x 133.7 2.1 * 7 133.7 = 14.7 133.7 = 148.4 }

我们运行测试类,得到运行结果:

1,136
2,143
3,132
4,142
5,147

数据点个数 n = 5

Sum x  = 15.0
Sum y  = 700.0
Sum xx = 55.0
Sum xy = 2121.0
Sum yy = 98142.0

回归线公式:  y = 2.1x + 133.7
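误差:     R^2 = 0.3658

（注：0.3658 是原 getR() 的循环条件 i < pn - 1 漏掉了最后一个数据点所致。按 R² = 1 − SSE/SST 对全部 5 个点计算，SSE = 97.9，SST = 142，正确值为 0.3106。另外，R² 衡量的是拟合度，并不是"误差"。）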

假如某公司:

1月收入,136万元
2月收入,143万元
3月收入,132万元
4月收入,142万元
5月收入,147万元

我们可以根据回归线公式 $y = 2.1x + 133.7$，取 $x = 6$，预测出 6 月份的收入：

$y = 2.1 \times 6 + 133.7 = 12.6 + 133.7 = 146.3$（万元）

需要注意，这组数据的 $R^2$ 只有约 0.31，拟合度较低，该预测值仅供参考。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约