MNIST image classification and comparison based on K-nearest neighbor
2022-06-22 00:04:00 【WihauShe】
Reading the data set
Because the website hosting the data is unstable, I downloaded the data set locally and read it from disk.
Most examples online read the data into a three-dimensional array so the images are easy to display. However, sklearn expects a two-dimensional array, which is also more convenient for the distance calculations, so I changed the code to read each image as a flat 784-element row.
def decode_idx3_ubyte(idx3_ubyte_file):
    """General function for parsing an idx3 file.
    :param idx3_ubyte_file: path to the idx3 file
    :return: image array
    """
    # Read the binary data
    bin_data = gzip.open(idx3_ubyte_file, 'rb').read()
    # Parse the file header: magic number, number of images, image height, image width
    offset = 0
    fmt_header = '>IIII'
    # Parse the image data
    offset += struct.calcsize(fmt_header)
    fmt_image = '>784B'
    image_size = 100
    # Use more samples when reading the training set
    if 'train' in idx3_ubyte_file:
        image_size = 6000
    images = np.empty((image_size, 784))
    for i in range(image_size):
        temp = struct.unpack_from(fmt_image, bin_data, offset)
        images[i] = np.reshape(temp, 784)
        offset += struct.calcsize(fmt_image)
    return images
def decode_idx1_ubyte(idx1_ubyte_file):
    """General function for parsing an idx1 file.
    :param idx1_ubyte_file: path to the idx1 file
    :return: label array
    """
    # Read the binary data
    bin_data = gzip.open(idx1_ubyte_file, 'rb').read()
    # Parse the file header: magic number and number of labels
    offset = 0
    fmt_header = '>II'
    # Parse the label data
    offset += struct.calcsize(fmt_header)
    fmt_label = '>B'
    label_size = 100
    # Use more samples when reading the training set
    if 'train' in idx1_ubyte_file:
        label_size = 6000
    labels = np.empty(label_size, dtype=int)
    for i in range(label_size):
        labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
        offset += struct.calcsize(fmt_label)
    return labels
The `image_size` / `label_size` values control how many samples are read: 6,000 training samples (one tenth of MNIST's 60,000) and 100 test samples.
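If you would rather not hard-code the sample counts, the number of images can be read from the IDX header itself. Below is a minimal sketch of this idea (my own addition, using a hypothetical helper name read_idx3_header) that simply unpacks the same '>IIII' header used above:

    import gzip
    import struct

    def read_idx3_header(path):
        # Hypothetical helper: return (magic, num_images, rows, cols) from a gzipped idx3 file
        with gzip.open(path, 'rb') as f:
            header = f.read(struct.calcsize('>IIII'))
        return struct.unpack_from('>IIII', header, 0)

    # Usage sketch: read the real image count, then take a fraction of it
    magic, num_images, rows, cols = read_idx3_header('train-images-idx3-ubyte.gz')
    subset_size = num_images // 10   # 6000 for the 60,000-image MNIST training set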
Implementing the k-nearest neighbor algorithm
class NearstNeighbour:
    def __init__(self, k):
        self.k = k

    def train(self, X, y):
        self.Xtr = X
        self.ytr = y
        return self

    def predict(self, test_images):
        predictions = []
        # This code is adapted from https://github.com/Youngphone/KNN-MNIST/blob/master/KNN-MNIST.ipynb
        # Loop over every test sample
        for test_item in test_images:
            datasetsize = self.Xtr.shape[0]
            # Euclidean distance between the test sample and every training sample
            diffMat = np.tile(test_item, (datasetsize, 1)) - self.Xtr
            sqDiffMat = diffMat ** 2
            sqDistances = sqDiffMat.sum(axis=1)
            distances = sqDistances ** 0.5
            # Sort the distances in ascending order and return the indices
            sortedDistIndicies = distances.argsort()
            # Dictionary of votes per class
            classCount = {}
            # Look at the k smallest distances
            for i in range(self.k):
                # sortedDistIndicies[i] is the index of the i-th closest training sample
                # self.ytr[sortedDistIndicies[i]] is the label of that sample
                voteIlabel = self.ytr[sortedDistIndicies[i]]
                # Add one vote for that label
                classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            # Sort the classes by vote count and take the most common one
            sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
            predictions.append(sortedClassCount[0][0])
        return predictions
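The per-sample loop above is straightforward but slow, since it rebuilds the distance array in Python for every test image. As a rough sketch of an alternative (my own addition, not part of the original article), the same Euclidean-distance k-NN vote can be expressed with NumPy broadcasting so that all pairwise distances are computed at once:

    import numpy as np

    def knn_predict_vectorized(Xtr, ytr, Xte, k=5):
        # Squared Euclidean distances between every test and every training sample;
        # the square root is omitted because it does not change the ranking.
        d2 = (Xte ** 2).sum(axis=1)[:, None] - 2 * Xte @ Xtr.T + (Xtr ** 2).sum(axis=1)[None, :]
        # Indices of the k closest training samples for each test sample
        nearest = np.argsort(d2, axis=1)[:, :k]
        # Majority vote over the labels of the k nearest neighbours
        return np.array([np.bincount(ytr[row].astype(int)).argmax() for row in nearest])

This trades memory (a full test-by-train distance matrix) for speed, which matters when comparing the running time against sklearn below.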
Comparison with sklearn's k-nearest neighbor
# -*- encoding: utf-8 -*-
'''
@File    : NearstNeighbour.py
@Time    : 2021/03/27 15:40:05
@Author  : Wihau
@Version : 1.0
@Desc    : None
'''
# here put the import lib
import gzip
import numpy as np
import struct
import operator
import time
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
train_images_idx3_ubyte_file = 'train-images-idx3-ubyte.gz'
train_labels_idx1_ubyte_file = 'train-labels-idx1-ubyte.gz'
test_images_idx3_ubyte_file = 't10k-images-idx3-ubyte.gz'
test_labels_idx1_ubyte_file = 't10k-labels-idx1-ubyte.gz'
def decode_idx3_ubyte(idx3_ubyte_file):
    """General function for parsing an idx3 file.
    :param idx3_ubyte_file: path to the idx3 file
    :return: image array
    """
    # Read the binary data
    bin_data = gzip.open(idx3_ubyte_file, 'rb').read()
    # Parse the file header: magic number, number of images, image height, image width
    offset = 0
    fmt_header = '>IIII'
    # Parse the image data
    offset += struct.calcsize(fmt_header)
    fmt_image = '>784B'
    image_size = 100
    # Use more samples when reading the training set
    if 'train' in idx3_ubyte_file:
        image_size = 6000
    images = np.empty((image_size, 784))
    for i in range(image_size):
        temp = struct.unpack_from(fmt_image, bin_data, offset)
        images[i] = np.reshape(temp, 784)
        offset += struct.calcsize(fmt_image)
    return images
def decode_idx1_ubyte(idx1_ubyte_file):
    """General function for parsing an idx1 file.
    :param idx1_ubyte_file: path to the idx1 file
    :return: label array
    """
    # Read the binary data
    bin_data = gzip.open(idx1_ubyte_file, 'rb').read()
    # Parse the file header: magic number and number of labels
    offset = 0
    fmt_header = '>II'
    # Parse the label data
    offset += struct.calcsize(fmt_header)
    fmt_label = '>B'
    label_size = 100
    # Use more samples when reading the training set
    if 'train' in idx1_ubyte_file:
        label_size = 6000
    labels = np.empty(label_size, dtype=int)
    for i in range(label_size):
        labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
        offset += struct.calcsize(fmt_label)
    return labels
def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)

def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)

def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)

def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)
class NearstNeighbour:
    def __init__(self, k):
        self.k = k

    def train(self, X, y):
        self.Xtr = X
        self.ytr = y
        return self

    def predict(self, test_images):
        predictions = []
        # Loop over every test sample
        for test_item in test_images:
            datasetsize = self.Xtr.shape[0]
            # Euclidean distance between the test sample and every training sample
            diffMat = np.tile(test_item, (datasetsize, 1)) - self.Xtr
            sqDiffMat = diffMat ** 2
            sqDistances = sqDiffMat.sum(axis=1)
            distances = sqDistances ** 0.5
            # Sort the distances in ascending order and return the indices
            sortedDistIndicies = distances.argsort()
            # Dictionary of votes per class
            classCount = {}
            # Look at the k smallest distances
            for i in range(self.k):
                # sortedDistIndicies[i] is the index of the i-th closest training sample
                # self.ytr[sortedDistIndicies[i]] is the label of that sample
                voteIlabel = self.ytr[sortedDistIndicies[i]]
                # Add one vote for that label
                classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            # Sort the classes by vote count and take the most common one
            sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
            predictions.append(sortedClassCount[0][0])
        return predictions
train_images = load_train_images()
train_labels = load_train_labels()
test_images = load_test_images()
test_labels = load_test_labels()
k = 5
# Personal k-nearest neighbor prediction
print("-----Personal k nearest neighbour-----")
# Time the prediction
start = time.time()
knn = NearstNeighbour(k)
predictions = knn.train(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# Accuracy
accuracy = accuracy_score(test_labels, predictions)
print("accuracy score:", accuracy)
# Confusion matrix
matrix = confusion_matrix(test_labels, predictions)
print(matrix)
# sklearn k-nearest neighbor prediction
print("-----Sklearn nearest neighbour-----")
# Time the prediction
start = time.time()
sknn = KNeighborsClassifier(n_neighbors = k)
skpredictions = sknn.fit(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# Accuracy
skaccuracy = accuracy_score(test_labels, skpredictions)
print("accuracy score:", skaccuracy)
# Confusion matrix
skmatrix = confusion_matrix(test_labels, skpredictions)
print(skmatrix)
The results are as follows.
When k = 5: (result screenshot)
When k = 10: (result screenshot)