在使用回归模型分析实际问题中,所研究的变量往往不全是区间变量而是顺序变量或属性变量,比如二项分布问题。通过分析年龄、性别、体质指数、平均血压、疾病指数等指标,判断一个人是否换糖尿病,Y=0表示未患病,Y=1表示患病,这里的响应变量是一个两点(0或1)分布变量,它就不能用映射函数h连续的值来预测因变量Y(Y只能取0或1)。 线性回归或多项式回归模型通常是处理因变量为连续变量的问题,如果因变量是定性变量,线性回归模型就不再适用了,此时需采用逻辑回归模型解决。 1.基本概念 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题,常见的是二分类或二项分布问题,也可以处理多分类问题,它实际上是属于一种分类方法。 二分类问题的概率与自变量之间的关系图形往往是一个S型曲线,采用的Sigmoid函数实现。这里我们将该函数定义如下: 函数的定义域为全体实数,值域在[0,1]之间,x轴在0点对应的结果为0.5。当x取值足够大的时候,可以看成0或1两类问题,大于0.5可以认为是1类问题,反之是0类问题,而刚好是0.5,则可以划分至0类或1类。 采用线性模型进行分析,其公式变换如下: 而实际应用中,概率p与因变量往往是非线性的,为了解决该类问题,我们引入了logit变换,使得logit§与自变量之间存在线性相关的关系,逻辑回归模型定义如下: 通过推导,概率p变换如下,这与Sigmoid函数相符,也体现了概率p与因变量之间的非线性关系。以0.5为界限,预测p大于0.5时,我们判断此时y更可能为1,否则y为0。 2.LogisticRegression LogisticRegression回归模型在Sklearn.linear_model子类下。 调用sklearn逻辑回归算法步骤: 导入模型。调用逻辑回归LogisticRegression()函数。 fit()训练。调用fit(x,y)的方法来训练模型,其中x为数据的属性,y为所属类型。 predict()预测。利用训练得到的模型对数据集进行预测,返回预测结果。 3.鸢尾花数据集回归分析实例 在Sklearn机器学习包中,集成了各种各样的数据集,这里引入的是鸢尾花卉(Iris)数据集,它也是一个很常用的数据集。该数据集一共包含4个特征变量,1个类别变量,共有150个样本。其中四个特征分别是萼片的长度和宽度、花瓣的长度和宽度,一个类别变量是标记鸢尾花所属的分类情况,该值包含三种情况,即山鸢尾、变色鸢尾和维吉尼亚鸢尾。 iris里有两个属性iris.data,iris.target。data是一个矩阵,每一列代表了萼片或花瓣的长宽,一共4列,每一行代表一个被测量的鸢尾植物,一共采样了150条记录,即150朵鸢尾花样本。 from sklearn.datasets import load_iris from sklearn.linear_model import LogisticRegression import matplotlib.pyplot as plt import numpy as np iris=load_iris()#载入数据集 d=iris.data x=[x[0] for x in d] y=[x[1] for x in d] #绘制散点图 plt.scatter(x[:50],y[:50],color='r',marker='o',label='setosa') plt.scatter(x[50:100],y[50:100],color='b',marker='x',label='versicolor') plt.scatter(x[100:],y[100:],color='g',marker='+',label='virginica') plt.legend() plt.show() 获取鸢尾花数据集的前两列数据,绘制三种类别的鸢尾散点图。 可以看出,数据集是线性可分的,划分为3类,分别对应三种类型的鸢尾花,下面采用逻辑回归对其进行分析预测。 #逻辑回归 x1=d[:,:2]#获取数据集的两列数据 y1=iris.target lr=LogisticRegression(C=1e5)#初始化逻辑回归模型,C=1e5表示目标函数 lr.fit(x1,y1)#调用逻辑回归模型进行训练,参数X为数据特征,参数Y为数据类标 h=0.02 #取第一列的最小值、最大值和步长h生成数组 x1min,x1max=x1[:,0].min()-.5,x1[:,0].max()+.5 y1min,y1max=x1[:,1].min()-.5,x1[:,1].max()+.5 #meshgrid函数生成两个网格矩阵xx和yy xx,yy=np.meshgrid(np.arange(x1min,x1max,h),np.arange(y1min,y1max,h)) #ravel()函数将xx和yy的两个矩阵转变成一维数组 #np.c_[xx.ravel(), yy.ravel()]是获取并合并成矩阵 z=lr.predict(np.c_[xx.ravel(),yy.ravel()]) z=z.reshape(xx.shape) plt.figure(1,figsize=(8,6)) plt.pcolormesh(xx,yy,z,cmap=plt.cm.Paired) #绘制散点图 plt.scatter(x1[:50,0],x1[:50,1],color='r',marker='o',label='setosa') plt.scatter(x1[50:100,0],x1[50:100,1],color='b',marker='x',label='versicolor') plt.scatter(x1[100:,0],x1[100:,1],color='g',marker='+',label='virginica') plt.xlabel('sepal length') plt.ylabel('sepal width') plt.xlim(xx.min(),xx.max()) plt.ylim(yy.min(),yy.max()) plt.xticks() plt.yticks() plt.legend() plt.show() 经过逻辑回归后划分为三个区域,左上角部分为红色的圆点,对应setosa鸢尾花;右上角部分为绿色方块,对应virginica鸢尾花;中间下部分为蓝色星形,对应versicolor鸢尾花。 散点图为各数据点真实的花类型,划分的三个区域为数据点预测的花类型,预测的分类结果与训练数据的真实结果结果基本一致,部分鸢尾花出现交叉。 |
|