Stanford cs231n course assignment - nearest neighbor classifier
2022-07-23 05:10:00 【Fu_ Tianshu】
Course website: http://cs231n.stanford.edu/
Course materials: http://cs231n.stanford.edu/syllabus.html
Course PDF: http://cs231n.stanford.edu/slides/2020/lecture_2.pdf
Assignment notes: https://cs231n.github.io/classification/#k---nearest-neighbor-classifier
CIFAR-10 dataset official site: http://www.cs.toronto.edu/~kriz/cifar.html
Here is the code; note the path used to load the data files.
import numpy as np
import os
import pickle
import time
class NearestNeighbor(object):
    def __init__(self):
        pass

    # There is no real training for this classifier;
    # it simply keeps all of the training data in memory.
    def train(self, X, y):
        # X is a 50000x3072 array, y is a length-50000 array of labels
        self.Xtr = X
        self.ytr = y
    def predict(self, X):
        # Number of test images; 10000 in this example
        num_test = X.shape[0]
        # Allocate a length-10000 vector of zeros to hold the predicted labels.
        # Its element type matches the training labels (int32 in this example).
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        # Loop over every test image, 10000 iterations in total
        for i in range(num_test):
            print("testing %d" % i)
            # L1 distance: distance from every training image to the i-th test image
            # np.abs computes element-wise absolute values
            # np.sum(..., axis=1) sums along the second axis (over the 3072 pixel values)
            # distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            # L2 distance: distance from every training image to the i-th test image
            # np.square squares every element of the array
            # np.sum(..., axis=1) sums along the second axis (over the 3072 pixel values)
            # np.sqrt takes the square root
            distances = np.sqrt(np.sum(np.square(self.Xtr - X[i, :]), axis=1))
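            # For reference, the two distance functions written out as formulas,
            # where the sum runs over all pixel positions p of training image I1
            # and test image I2:
            #   L1: d1(I1, I2) = sum_p |I1_p - I2_p|
            #   L2: d2(I1, I2) = sqrt(sum_p (I1_p - I2_p)^2)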
            # np.argmin returns the index of the smallest entry in distances,
            # i.e. the training image closest to this test image
            min_index = np.argmin(distances)
            Ypred[i] = self.ytr[min_index]
        return Ypred
def load_CIFAR10(path):
    xs = []
    ys = []
    # Loop over the five training batches; b takes the values 1, 2, 3, 4, 5
    for b in range(1, 6):
        # os.path.join joins the directory and file name into a path
        f = os.path.join(path, 'data_batch_%d' % (b,))
        # Load each data_batch_b file
        X, Y = load_CIFAR_batch(f)
        # After the five loads, xs holds the 50000 training images and
        # ys holds the 50000 labels (integers 0-9)
        xs.append(X)
        ys.append(Y)
    # np.concatenate joins the five 10000-image arrays along the first axis
    # into a single 50000-image array (and likewise for the labels)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    # Delete X and Y so the last batch is not kept in memory twice
    del X, Y
    # Load the test_batch file
    Xte, Yte = load_CIFAR_batch(os.path.join(path, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
def load_CIFAR_batch(filename):
    # 'with' closes the file automatically when done; 'rb' opens it read-only in binary mode
    with open(filename, 'rb') as f:
        # pickle.load deserializes the file contents back into a Python object
        datadict = pickle.load(f, encoding='latin1')
        # Deserializing a batch file yields a dictionary with the keys 'data' and 'labels'
        '''
        data is a 10000x3072 numpy array; each row stores one 32x32 color image.
        The first 1024 numbers are the red channel, the next 1024 are green,
        and the last 1024 are blue.
        labels is a list of 10000 numbers in the range 0-9.
        '''
        X = datadict['data']
        Y = datadict['labels']
        # reshape changes the shape of the data, transpose swaps the axes of the tensor,
        # and astype changes the element type of the np.array
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        # Convert Y from a list into an np.array
        Y = np.array(Y)
        return X, Y
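
# Optional sanity check for the layout described above (a minimal sketch, assuming
# matplotlib is installed; not needed by the rest of the script): each row of X
# becomes a 32x32x3 float array, so one image can be displayed with
#   import matplotlib.pyplot as plt
#   plt.imshow(X[0].astype('uint8')); plt.show()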
def runNN():
    print("load data")
    # Xtr holds the training images, Ytr the training labels
    # Xte holds the test images, Yte the test labels
    Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar-10-batches-py')
    # Flatten Xtr and Xte into 50000x3072 and 10000x3072 arrays respectively
    Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)
    Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)
    # Construct a NearestNeighbor object nn
    nn = NearestNeighbor()
    print("start training")
    # "Train" nn on the training data (this only stores the data)
    nn.train(Xtr_rows, Ytr)
    print("start testing")
    # Use the trained classifier nn to predict labels for the test images
    Yte_predict = nn.predict(Xte_rows)
    # Report the test accuracy
    print('accuracy: %f' % (np.mean(Yte_predict == Yte)))
if __name__ == '__main__':
    print("Start time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    runNN()
    print("End time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
Output when using the L1 distance:
accuracy: 0.385900
Output when using the L2 distance:
accuracy: 0.353900
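The loop in predict is easy to read but slow, because each of the 10000 test images is compared with all 50000 training images in interpreted Python. A common NumPy speed-up is to compute the whole test-by-training L2 distance matrix in one shot using the expansion ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2. The function below is a minimal sketch of that idea, written against the Xtr_rows, Ytr and Xte_rows arrays built in runNN and reusing the numpy import from the script above; the name predict_vectorized is illustrative and not part of the course code.

def predict_vectorized(Xtr_rows, Ytr, Xte_rows):
    # Squared L2 distance between every test image and every training image,
    # via ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2, with no Python loop.
    test_sq = np.sum(np.square(Xte_rows), axis=1, keepdims=True)  # (num_test, 1)
    train_sq = np.sum(np.square(Xtr_rows), axis=1)                # (num_train,)
    cross = Xte_rows.dot(Xtr_rows.T)                              # (num_test, num_train)
    dists = np.sqrt(np.maximum(test_sq - 2.0 * cross + train_sq, 0.0))
    # Closest training image for each test image, then look up its label.
    return Ytr[np.argmin(dists, axis=1)]

Note that the full 10000x50000 distance matrix takes several gigabytes as float64, so in practice the test rows are usually processed in smaller chunks.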