分享

【机器学习】逻辑回归

 时予可 2023-04-28 发布于山西

在使用回归模型分析实际问题中,所研究的变量往往不全是区间变量而是顺序变量或属性变量,比如二项分布问题。通过分析年龄、性别、体质指数、平均血压、疾病指数等指标,判断一个人是否换糖尿病,Y=0表示未患病,Y=1表示患病,这里的响应变量是一个两点(0或1)分布变量,它就不能用映射函数h连续的值来预测因变量Y(Y只能取0或1)。

线性回归或多项式回归模型通常是处理因变量为连续变量的问题,如果因变量是定性变量,线性回归模型就不再适用了,此时需采用逻辑回归模型解决。

1.基本概念

逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题,常见的是二分类或二项分布问题,也可以处理多分类问题,它实际上是属于一种分类方法。

二分类问题的概率与自变量之间的关系图形往往是一个S型曲线,采用的Sigmoid函数实现。这里我们将该函数定义如下:

函数的定义域为全体实数,值域在[0,1]之间,x轴在0点对应的结果为0.5。当x取值足够大的时候,可以看成0或1两类问题,大于0.5可以认为是1类问题,反之是0类问题,而刚好是0.5,则可以划分至0类或1类。

采用线性模型进行分析,其公式变换如下:

而实际应用中,概率p与因变量往往是非线性的,为了解决该类问题,我们引入了logit变换,使得logit§与自变量之间存在线性相关的关系,逻辑回归模型定义如下:

通过推导,概率p变换如下,这与Sigmoid函数相符,也体现了概率p与因变量之间的非线性关系。以0.5为界限,预测p大于0.5时,我们判断此时y更可能为1,否则y为0。

2.LogisticRegression

LogisticRegression回归模型在Sklearn.linear_model子类下。

调用sklearn逻辑回归算法步骤:

导入模型。调用逻辑回归LogisticRegression()函数。

fit()训练。调用fit(x,y)的方法来训练模型,其中x为数据的属性,y为所属类型。

predict()预测。利用训练得到的模型对数据集进行预测,返回预测结果。

3.鸢尾花数据集回归分析实例

在Sklearn机器学习包中,集成了各种各样的数据集,这里引入的是鸢尾花卉(Iris)数据集,它也是一个很常用的数据集。该数据集一共包含4个特征变量,1个类别变量,共有150个样本。其中四个特征分别是萼片的长度和宽度、花瓣的长度和宽度,一个类别变量是标记鸢尾花所属的分类情况,该值包含三种情况,即山鸢尾、变色鸢尾和维吉尼亚鸢尾。

iris里有两个属性iris.data,iris.target。data是一个矩阵,每一列代表了萼片或花瓣的长宽,一共4列,每一行代表一个被测量的鸢尾植物,一共采样了150条记录,即150朵鸢尾花样本。

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import numpy as np

iris=load_iris()#载入数据集
d=iris.data
x=[x[0] for x in d]
y=[x[1] for x in d]
#绘制散点图
plt.scatter(x[:50],y[:50],color='r',marker='o',label='setosa')
plt.scatter(x[50:100],y[50:100],color='b',marker='x',label='versicolor')
plt.scatter(x[100:],y[100:],color='g',marker='+',label='virginica')
plt.legend()
plt.show()

获取鸢尾花数据集的前两列数据,绘制三种类别的鸢尾散点图。

可以看出,数据集是线性可分的,划分为3类,分别对应三种类型的鸢尾花,下面采用逻辑回归对其进行分析预测。

#逻辑回归
x1=d[:,:2]#获取数据集的两列数据
y1=iris.target
lr=LogisticRegression(C=1e5)#初始化逻辑回归模型,C=1e5表示目标函数
lr.fit(x1,y1)#调用逻辑回归模型进行训练,参数X为数据特征,参数Y为数据类标
h=0.02
#取第一列的最小值、最大值和步长h生成数组
x1min,x1max=x1[:,0].min()-.5,x1[:,0].max()+.5
y1min,y1max=x1[:,1].min()-.5,x1[:,1].max()+.5
#meshgrid函数生成两个网格矩阵xx和yy
xx,yy=np.meshgrid(np.arange(x1min,x1max,h),np.arange(y1min,y1max,h))
#ravel()函数将xx和yy的两个矩阵转变成一维数组
#np.c_[xx.ravel(), yy.ravel()]是获取并合并成矩阵
z=lr.predict(np.c_[xx.ravel(),yy.ravel()])
z=z.reshape(xx.shape)
plt.figure(1,figsize=(8,6))
plt.pcolormesh(xx,yy,z,cmap=plt.cm.Paired)
#绘制散点图
plt.scatter(x1[:50,0],x1[:50,1],color='r',marker='o',label='setosa')
plt.scatter(x1[50:100,0],x1[50:100,1],color='b',marker='x',label='versicolor')
plt.scatter(x1[100:,0],x1[100:,1],color='g',marker='+',label='virginica')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.xticks()
plt.yticks()
plt.legend()
plt.show()

经过逻辑回归后划分为三个区域,左上角部分为红色的圆点,对应setosa鸢尾花;右上角部分为绿色方块,对应virginica鸢尾花;中间下部分为蓝色星形,对应versicolor鸢尾花。

散点图为各数据点真实的花类型,划分的三个区域为数据点预测的花类型,预测的分类结果与训练数据的真实结果结果基本一致,部分鸢尾花出现交叉。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多