找回密码
 立即注册
首页 业界区 业界 KNN算法

KNN算法

简千叶 2025-9-25 10:58:09
 
 
1.png

 代码实现:
  1.   1 # !/usr/bin/python3
  2.   2 # -*- coding: utf-8 -*-
  3.   3
  4.   4 import pandas as pd
  5.   5 import numpy as np
  6.   6 from sklearn.datasets import make_classification
  7.   7 from sklearn.neighbors import KNeighborsClassifier
  8.   8 from sklearn.model_selection import train_test_split
  9.   9 from sklearn.metrics import confusion_matrix,classification_report # 混淆矩阵
  10. 10 from sklearn.metrics import precision_recall_curve,roc_auc_score,roc_curve,auc # ROC曲线等
  11. 11 import matplotlib.pyplot as plt
  12. 12
  13. 13 '''
  14. 14     1. 准备数据,划分训练集和测试集
  15. 15 '''
  16. 16
  17. 17 # 生成一个5000样本量,30个特征,3个分类的数据集
  18. 18 X,y = make_classification(n_samples=5000,n_features=30,n_classes=2,n_informative=2,random_state=10)
  19. 19
  20. 20 '''
  21. 21 make_classification 参数:
  22. 22     n_samples‌:生成的样本数量,默认为100。
  23. 23     n_features‌:特征数量,默认为20。
  24. 24     n_informative‌:信息性特征数量,默认为2。这些特征与输出类别有关。
  25. 25     n_redundant‌:冗余特征数量,默认为2。这些特征是信息性特征的线性组合。
  26. 26     n_repeated‌:重复特征数量,默认为0。这些特征是从其他特征中复制的。
  27. 27     n_classes‌:类别数量,默认为2。
  28. 28     n_clusters_per_class‌:每个类别中的簇数量,默认为2。
  29. 29     weights‌:每个类别的样本权重,默认为None。
  30. 30     flip_y‌:标签翻转概率,默认为0.01,用于增加噪声。
  31. 31     class_sep‌:类间分离因子,默认为1.0。值越大,类分离越明显。
  32. 32     hypercube‌:布尔值,指定特征是否在超立方体中生成,默认为True。
  33. 33     shift‌和‌scale‌:用于特征的偏移和平移。
  34. 34     shuffle‌:布尔值,指定生成数据后是否打乱数据,默认为True。
  35. 35     random_state‌:随机数生成器的状态或种子,用于确保数据可重复。
  36. 36     ‌返回值‌包括两个数组:X(形状为[n_samples, n_features]的特征矩阵)和y(形状为[n_samples]的目标向量
  37. 37
  38. 38 '''
  39. 39
  40. 40 # 将数据集划分为训练集和测试集
  41. 41
  42. 42 Xtrain,Xtest,Ytrain,Ytest = train_test_split(X,y,test_size=0.3,random_state=1)
  43. 43
  44. 44
  45. 45 '''
  46. 46     2. 建立模型&评估模型
  47. 47 '''
  48. 48 k_values = [1,3,5,7]
  49. 49
  50. 50 for k in k_values:
  51. 51     clf = KNeighborsClassifier(n_neighbors=k) # 实例化模型
  52. 52     clf = clf.fit(Xtrain,Ytrain,) # 使用训练集训练模型
  53. 53     score = clf.score(Xtest,Ytest) # 看模型在新数据集(测试集)上的预测效果
  54. 54     print(score) # 准确率
  55. 55     # 看测试集上的获得的预测概率
  56. 56
  57. 57     y_prob_1 = clf.predict_proba(Xtest)[:,1]
  58. 58     # print(y_prob_1)
  59. 59
  60. 60     '''
  61. 61     predict_proba返回的是一个n行k列的数组,其中每一行代表一个测试样本,每一列代表一个类别。
  62. 62     例如,对于二分类问题,返回的数组有两列,第一列表示属于第一个类别的概率,第二列表示属于第二个类别的概率。
  63. 63     '''
  64. 64     # print(y_prob_0)
  65. 65     # print(y_prob_1)
  66. 66     # print(y_prob_2)
  67. 67     # print(y_prob)
  68. 68
  69. 69
  70. 70     '''
  71. 71         3. 绘制ROC曲线,PR曲线
  72. 72     '''
  73. 73
  74. 74
  75. 75     # 假正率(FPR)、真正率(TPR)和阈值(thresholds)
  76. 76     '''
  77. 77     假正率(FPR):假正率表示在实际为负例的样本中,被模型错误预测为正例的比例。
  78. 78     真正率(TPR):真正率也称为灵敏度(Sensitivity)或召回率(Recall),它表示在实际为正例的样本中,被模型正确预测为正例的比例。
  79. 79     '''
  80. 80
  81. 81     FPR, TPR, thresholds = roc_curve(Ytest,y_prob_1)
  82. 82
  83. 83     '''
  84. 84     ROC曲线‌:ROC曲线主要用于衡量二分类器的性能。它以假正率(FPR)为横坐标,真正率(TPR)为纵坐标,绘制出分类器的性能曲线。
  85. 85     ROC曲线越靠近左上角,表示分类器的性能越好‌
  86. 86     ROC曲线越靠近左上角(0, 1)点,说明分类器的性能越好。
  87. 87
  88. 88     AUC(Area Under the Curve)是ROC 曲线下方的面积,范围在0到1之间,可以理解为模型正确区分正例和反例的能力。
  89. 89     一个完美的分类器的AUC值为1,而一个随机猜测的分类器的AUC值为0.5。
  90. 90     
  91. 91     '''
  92. 92
  93. 93     # 计算ROC曲线的参数
  94. 94
  95. 95     ROC_AUC = auc(FPR, TPR)
  96. 96
  97. 97     # 绘制ROC曲线
  98. 98     plt.subplot(1, 2, 2)
  99. 99     plt.plot(FPR, TPR, lw=2, label=f'k={k}, AUC = {ROC_AUC:.2f}')
  100. 100     plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  101. 101     plt.legend(loc="lower right")
  102. 102
  103. 103     plt.grid(True)
  104. 104     plt.title('ROC')
  105. 105     plt.xlabel('FPR')
  106. 106     plt.ylabel('TPR')
  107. 107     # plt.show()
  108. 108
  109. 109     '''
  110. 110     PR 曲线,全称为 Precision - Recall 曲线,是用于评估分类模型性能的重要工具
  111. 111     定义与原理:PR 曲线通过绘制精度(Precision)与召回率(Recall)之间的关系曲线,来展示模型在不同阈值下的表现。
  112. 112     精度表示在所有被预测为正类的样本中实际为正类的比例,召回率表示在所有实际为正类的样本中被正确预测为正类的比例。
  113. 113     通过改变分类阈值,可以得到一系列不同的精度和召回率值,将这些值绘制成曲线,就得到了 PR 曲线。
  114. 114     
  115. 115     绘制方法:首先计算模型在不同阈值下的精度和召回率。然后,以召回率为横坐标,精度为纵坐标,将各个阈值下对应的点连接起来,形成 PR 曲线。
  116. 116     PR曲线越靠近右上角(1, 1)点,说明分类器的性能越好。
  117. 117     '''
  118. 118
  119. 119
  120. 120     # 绘制PR曲线
  121. 121     # 绘制不同K值的kNN分类器在测试集上的PR曲线,并计算对应的AUC值。
  122. 122     # precision_recall_curve 返回结果依次是precision(精确度)、recall(召回率)和 thresholds(阈值)
  123. 123     precision, recall, thresholds = precision_recall_curve(Ytest,y_prob_1)
  124. 124     pr_auc = auc(recall, precision)
  125. 125
  126. 126     # 绘制PR曲线
  127. 127     plt.subplot(1, 2, 1)
  128. 128     plt.plot(recall, precision, lw=2, label=f'k={k}, AUC = {pr_auc:.2f}')
  129. 129     plt.legend(loc="lower left")
  130. 130     plt.grid(True)
  131. 131     plt.title('PR')
  132. 132     plt.xlabel('recall')
  133. 133     plt.ylabel('precision')
  134. 134
  135. 135 plt.subplots_adjust(hspace=0.3, wspace=0.3)  # 调整间距
  136. 136 plt.show()
  137. 137
  138. 138
  139. 139 '''
  140. 140     4. 确定K值,确定最终模型
  141. 141 '''
  142. 142 clf = KNeighborsClassifier(n_neighbors=7)  # 实例化模型
  143. 143 clf = clf.fit(Xtrain, Ytrain, )  # 使用训练集训练模型
  144. 144 score = clf.score(Xtest, Ytest)  # 看模型在新数据集(测试集)上的预测效果
  145. 145 print(score)  # 准确率
  146. 146
  147. 147 print("训练集上的预测准确率为:", clf.score(Xtrain, Ytrain))
  148. 148 print("测试集上的预测准确率为:", clf.score(Xtest, Ytest))
  149. 149
  150. 150 print('混淆矩阵:',confusion_matrix(Ytest, clf.predict(Xtest)))
  151. 151 '''
  152. 152 混淆矩阵是机器学习中用于评估分类模型性能的重要工具。它通过表格形式直观展示模型预测结果与真实标签的对比,帮助分析分类错误的具体类型。
  153. 153 以下是一个二分类混淆矩阵的表格:
  154. 154
  155. 155              实际为正例    实际为反例
  156. 156 预测为正例    真正例(TP)    假正例(FP)
  157. 157 预测为反例    真反例(TN)    真反例(FN)
  158. 158 真正例(True Positive,TP):正确预测为正类的样本数。
  159. 159 假正例(False Positive,FP):实际为负类但被错误预测为正类的样本数。
  160. 160 假反例(False Negative,FN):实际为正类但被错误预测为负类的样本数。
  161. 161 真反例(True Negative,TN):正确预测为负类的样本数。
  162. 162 通过混淆矩阵,我们可以求得准确率、精确率、召回率等性能指标
  163. 163
  164. 164 '''
  165. 165 print('AUC值:',roc_auc_score(Ytest, clf.predict(Xtest)))
  166. 166
  167. 167 print('整体情况:',classification_report(Ytest, clf.predict(Xtest)))
复制代码
 
  1. <br>结果展示:
复制代码
2.png

 
  1. <br><br>
复制代码
3.png

 
  1. <br><br><br><br>
复制代码
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册