当前位置:网站首页>基于k近邻的MNIST图像分类对比
基于k近邻的MNIST图像分类对比
2022-06-21 18:08:00 【WihauShe】
数据集读取
由于数据来源网站不稳定,个人将数据集下载到本地后进行读取
网上多数教程将数据集读取为三维数组以方便显示;但为了计算方便,且 sklearn 接受的输入也是二维数组,所以个人后来将读取结果改成了二维数组。
def decode_idx3_ubyte(idx3_ubyte_file):
    """Parse a gzipped MNIST idx3 (image) file.

    :param idx3_ubyte_file: path to the gzipped idx3 image file
    :return: ndarray of shape (n, 784) with one flattened image per row
    """
    # Read the whole decompressed stream. `with` guarantees the file
    # handle is closed (the original left it open).
    with gzip.open(idx3_ubyte_file, 'rb') as f:
        bin_data = f.read()
    # Header: magic number, image count, rows, cols (big-endian uint32).
    # Header values are deliberately ignored: the sample count is
    # hard-coded below so only 1/10 of the dataset is used.
    offset = struct.calcsize('>IIII')
    # One image = 784 unsigned bytes (28 x 28, flattened).
    fmt_image = '>784B'
    # Training files yield 6000 images, test files 100 (1/10 of MNIST).
    image_size = 6000 if 'train' in idx3_ubyte_file else 100
    images = np.empty((image_size, 784))
    # Hoisted out of the loop: the record size never changes.
    record_size = struct.calcsize(fmt_image)
    for i in range(image_size):
        # unpack_from returns a 784-tuple; assignment broadcasts it
        # into the i-th row, no explicit reshape needed.
        images[i] = struct.unpack_from(fmt_image, bin_data, offset)
        offset += record_size
    return images
def decode_idx1_ubyte(idx1_ubyte_file):
    """Parse a gzipped MNIST idx1 (label) file.

    :param idx1_ubyte_file: path to the gzipped idx1 label file
    :return: 1-D integer ndarray of labels
    """
    # `with` guarantees the gzip handle is closed (the original leaked it).
    with gzip.open(idx1_ubyte_file, 'rb') as f:
        bin_data = f.read()
    # Header: magic number and label count (big-endian uint32), skipped;
    # the sample count is hard-coded below (1/10 of the dataset).
    offset = struct.calcsize('>II')
    # One label = one unsigned byte.
    fmt_label = '>B'
    # Training files yield 6000 labels, test files 100.
    label_size = 6000 if 'train' in idx1_ubyte_file else 100
    # np.int was removed in NumPy 1.24 — use an explicit dtype instead.
    labels = np.empty(label_size, np.int64)
    record_size = struct.calcsize(fmt_label)
    for i in range(label_size):
        labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
        offset += record_size
    return labels
这里控制了读取的数量,只使用了原数据集的十分之一
实现k近邻算法
class NearstNeighbour:
    """Brute-force k-nearest-neighbour classifier using Euclidean distance."""

    def __init__(self, k):
        # Number of neighbours consulted for each prediction.
        self.k = k

    def train(self, X, y):
        """Memorise the training data; returns self so calls can be chained."""
        self.Xtr = X
        self.ytr = y
        return self

    def predict(self, test_images):
        """Predict one label per row of ``test_images``.

        :return: list of predicted labels
        """
        results = []
        for sample in test_images:
            # Euclidean distance from this sample to every training row;
            # broadcasting does the per-row subtraction directly.
            deltas = self.Xtr - sample
            dists = (deltas ** 2).sum(axis=1) ** 0.5
            # Training-row indices ordered nearest-first.
            order = dists.argsort()
            # Majority vote over the k nearest labels.
            tally = {}
            for idx in order[:self.k]:
                label = self.ytr[idx]
                tally[label] = tally.get(label, 0) + 1
            # max() returns the first label reaching the top count in
            # insertion order — the same tie-breaking as a stable
            # descending sort on the counts.
            winner = max(tally.items(), key=operator.itemgetter(1))[0]
            results.append(winner)
        return results
与sklearn的k近邻对比
# -*- encoding: utf-8 -*-
''' @File : NearstNeighbour.py @Time : 2021/03/27 15:40:05 @Author : Wihau @Version : 1.0 @Desc : None '''
# here put the import lib
import gzip
import numpy as np
import struct
import operator
import time
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
train_images_idx3_ubyte_file = 'train-images-idx3-ubyte.gz'
train_labels_idx1_ubyte_file = 'train-labels-idx1-ubyte.gz'
test_images_idx3_ubyte_file = 't10k-images-idx3-ubyte.gz'
test_labels_idx1_ubyte_file = 't10k-labels-idx1-ubyte.gz'
def decode_idx3_ubyte(idx3_ubyte_file):
    """Parse a gzipped MNIST idx3 (image) file.

    :param idx3_ubyte_file: path to the gzipped idx3 image file
    :return: ndarray of shape (n, 784) with one flattened image per row
    """
    # Read the whole decompressed stream. `with` guarantees the file
    # handle is closed (the original left it open).
    with gzip.open(idx3_ubyte_file, 'rb') as f:
        bin_data = f.read()
    # Header: magic number, image count, rows, cols (big-endian uint32).
    # Header values are deliberately ignored: the sample count is
    # hard-coded below so only 1/10 of the dataset is used.
    offset = struct.calcsize('>IIII')
    # One image = 784 unsigned bytes (28 x 28, flattened).
    fmt_image = '>784B'
    # Training files yield 6000 images, test files 100 (1/10 of MNIST).
    image_size = 6000 if 'train' in idx3_ubyte_file else 100
    images = np.empty((image_size, 784))
    # Hoisted out of the loop: the record size never changes.
    record_size = struct.calcsize(fmt_image)
    for i in range(image_size):
        # unpack_from returns a 784-tuple; assignment broadcasts it
        # into the i-th row, no explicit reshape needed.
        images[i] = struct.unpack_from(fmt_image, bin_data, offset)
        offset += record_size
    return images
def decode_idx1_ubyte(idx1_ubyte_file):
    """Parse a gzipped MNIST idx1 (label) file.

    :param idx1_ubyte_file: path to the gzipped idx1 label file
    :return: 1-D integer ndarray of labels
    """
    # `with` guarantees the gzip handle is closed (the original leaked it).
    with gzip.open(idx1_ubyte_file, 'rb') as f:
        bin_data = f.read()
    # Header: magic number and label count (big-endian uint32), skipped;
    # the sample count is hard-coded below (1/10 of the dataset).
    offset = struct.calcsize('>II')
    # One label = one unsigned byte.
    fmt_label = '>B'
    # Training files yield 6000 labels, test files 100.
    label_size = 6000 if 'train' in idx1_ubyte_file else 100
    # np.int was removed in NumPy 1.24 — use an explicit dtype instead.
    labels = np.empty(label_size, np.int64)
    record_size = struct.calcsize(fmt_label)
    for i in range(label_size):
        labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
        offset += record_size
    return labels
def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    # Convenience wrapper: decode the training image file (default path
    # is the module-level constant defined above).
    return decode_idx3_ubyte(idx_ubyte_file)
def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    # Convenience wrapper: decode the training label file.
    return decode_idx1_ubyte(idx_ubyte_file)
def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    # Convenience wrapper: decode the test image file.
    return decode_idx3_ubyte(idx_ubyte_file)
def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    # Convenience wrapper: decode the test label file.
    return decode_idx1_ubyte(idx_ubyte_file)
class NearstNeighbour:
    """Brute-force k-nearest-neighbour classifier using Euclidean distance."""

    def __init__(self, k):
        # Number of neighbours consulted for each prediction.
        self.k = k

    def train(self, X, y):
        """Memorise the training data; returns self so calls can be chained."""
        self.Xtr = X
        self.ytr = y
        return self

    def predict(self, test_images):
        """Predict one label per row of ``test_images``.

        :return: list of predicted labels
        """
        results = []
        for sample in test_images:
            # Euclidean distance from this sample to every training row;
            # broadcasting does the per-row subtraction directly.
            deltas = self.Xtr - sample
            dists = (deltas ** 2).sum(axis=1) ** 0.5
            # Training-row indices ordered nearest-first.
            order = dists.argsort()
            # Majority vote over the k nearest labels.
            tally = {}
            for idx in order[:self.k]:
                label = self.ytr[idx]
                tally[label] = tally.get(label, 0) + 1
            # max() returns the first label reaching the top count in
            # insertion order — the same tie-breaking as a stable
            # descending sort on the counts.
            winner = max(tally.items(), key=operator.itemgetter(1))[0]
            results.append(winner)
        return results
# Load the subsampled MNIST data declared above.
train_images = load_train_images()
train_labels = load_train_labels()
test_images = load_test_images()
test_labels = load_test_labels()

# Number of neighbours used by both classifiers.
k = 5


def _evaluate(title, fit_and_predict):
    # Print the banner, time the full fit+predict call, then report
    # accuracy and the confusion matrix for the resulting predictions.
    print(title)
    started = time.time()
    predicted = fit_and_predict()
    print("time of prediction:%.3f s" % (time.time() - started))
    print("accuracy score:", accuracy_score(test_labels, predicted))
    print(confusion_matrix(test_labels, predicted))


# Hand-written kNN.
_evaluate(
    "-----Personal k nearest neighbour-----",
    lambda: NearstNeighbour(k).train(train_images, train_labels).predict(test_images),
)
# sklearn's kNN for comparison.
_evaluate(
    "-----Sklearn nearest neighbour-----",
    lambda: KNeighborsClassifier(n_neighbors=k).fit(train_images, train_labels).predict(test_images),
)
结果如下
k = 5 时
k = 10 时
边栏推荐
- 尚硅谷 尚硅谷 | 什么是ClickHouse表引擎 Memory和Merge
- Must the database primary key be self incremented? What scenarios do not suggest self augmentation?
- 出院小结识别api接口-医疗票据OCR识别/出院诊断记录/电子病历/理赔服务
- How to use devaxpress WPF to create the first MVVM application in winui?
- 2022年6月25日PMP考试通关宝典-4
- 剑指 Offer II 029. 排序的循环链表
- 力扣今日题1108. IP 地址无效化
- Flink 系例 之 TableAPI & SQL 与 示例模块
- 文献分析 Citespace 6.1.2 下载及安装教程
- 2022年下半年深圳地区数据分析师认证(CPDA),[进入查看]
猜你喜欢

API interface for discharge summary identification - medical bill OCR identification / discharge diagnosis record / electronic medical record / claim settlement service

How to use devaxpress WPF to create the first MVVM application in winui?

Mvcc implementation principle of MySQL

Medical expense list can be entered at a second speed, and OCR recognition can help double the efficiency
Must the database primary key be self incremented? What scenarios do not suggest self augmentation?

如何使用DevExpress WPF在WinUI中创建第一个MVVM应用?

Tableapi & SQL and example module of Flink

尚硅谷 尚硅谷 | 什么是ClickHouse表引擎 Memory和Merge

在Qt中设置程序图标的方法介绍

如何在Chrome浏览器中模拟请求或修改请求的域名
随机推荐
网管型全国产加固交换机如何创建网络冗余
企评家全面解读:【国家电网】中国电力财务有限公司企业成长性
11 Beautiful Soup 解析库的简介及安装
2022年6月25日PMP考试通关宝典-4
Medical expense list can be entered at a second speed, and OCR recognition can help double the efficiency
[high frequency interview questions] linked list interview questions with 1/5 difficulty and lower difficulty
Kubernetes 跨 StorageClass 迁移 Persistent Volumes 完全指南
Literature analysis CiteSpace 6.1.2 download and installation tutorial
Linux MySQL command
如何使用DevExpress WPF在WinUI中创建第一个MVVM应用?
508. Most Frequent Subtree Sum
6月22日直播 | 华南理工詹志辉: 面向昂贵优化的进化计算
如何在Chrome浏览器中模拟请求或修改请求的域名
第298场周赛
The R language catiols package divides the data, randomforest package constructs the random forest model, uses the importance function to calculate the importance of each feature in the random forest
2022年6月25日PMP考试通关宝典-3
三叶的小伙伴们の经历分享 : 千秋澪(千秋总)
Must the database primary key be self incremented? What scenarios do not suggest self augmentation?
《Go题库·9》同一个协程里面,对无缓冲channel同时发送和接收数据有什么问题
11 Beautiful Soup 解析庫的簡介及安裝