TensorFlow 2 Keras: Classification Model
2022-07-02 06:29:00 【图图是只猫】
```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow import keras
```
```python
# Dataset
fashion_mnist = keras.datasets.fashion_mnist
# Training and test sets
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
# Validation set (first 5000 images) and training set (the remaining 55000)
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
```
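```python
# Standardize the data: x = (x - mean) / std,
# i.e. subtract the mean and divide by the standard deviation.
# reshape(-1, 1) puts every pixel in one column, so a single global mean and std are learned;
# the scaler is fit on the training set only and then applied to the validation and test sets.
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(
    x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
    x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
    x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
```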
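Because `reshape(-1, 1)` puts every pixel into a single column, `StandardScaler` learns one mean and one standard deviation for the whole training set, not one per pixel position. The sketch below reproduces that computation with plain NumPy. `standardize_global` is a hypothetical helper name, and the random arrays are stand-ins for the Fashion-MNIST images, used only so the example runs without downloading the dataset.

```python
import numpy as np

def standardize_global(train, *others):
    """Fit one global mean/std on `train` and apply it to every array (hypothetical helper)."""
    train = train.astype(np.float32)
    mean = train.mean()
    std = train.std()  # population std (ddof=0), the same estimator StandardScaler uses
    return [(a.astype(np.float32) - mean) / std for a in (train, *others)]

# Random uint8 "images" standing in for the real data (assumption for a self-contained demo)
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
x_valid = rng.integers(0, 256, size=(20, 28, 28), dtype=np.uint8)

train_s, valid_s = standardize_global(x_train, x_valid)
print(float(train_s.mean()), float(train_s.std()))  # close to 0 and 1
```

Only the training set ends up with mean about 0 and std about 1; the validation set is shifted by the training statistics, which is exactly what `scaler.transform` does.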
```python
def show_single_image(img_arr):
    plt.imshow(img_arr, cmap="binary")
    plt.show()

def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    assert len(x_data) == len(y_data)
    # the grid may use every image, so equality is allowed
    assert n_rows * n_cols <= len(x_data)
    # figure width and height, in inches
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
            # create one subplot (subplot numbers start at 1)
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(x_data[index], cmap="binary", interpolation='nearest')
            plt.axis('off')
            plt.title(class_names[y_data[index]])
    plt.show()

class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# show_imgs(3, 5, x_train, y_train, class_names)
```
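In `show_imgs`, grid cell `(row, col)` displays image `n_cols * row + col`, and `plt.subplot` numbers its cells starting from 1, hence `index + 1`. A small sketch of that row-by-row mapping without any plotting, using a hypothetical `grid_positions` helper:

```python
def grid_positions(n_rows, n_cols):
    """Return (row, col, data_index, subplot_number) for each cell, row by row (hypothetical helper)."""
    cells = []
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col        # position in x_data / y_data
            cells.append((row, col, index, index + 1))  # subplot numbers are 1-based
    return cells

for cell in grid_positions(2, 3):
    print(cell)
```

A 3×5 grid therefore reads images 0 to 14, so it needs at least 15 images.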
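```python
# Build the model: Sequential is a linear stack of layers
model = keras.models.Sequential()
# Flatten each 28x28 matrix into a 784-element vector
model.add(keras.layers.Flatten(input_shape=[28, 28]))
# Dense: a fully connected layer; every unit receives every output of the previous layer
# Vanishing/exploding gradients: as the network gets deeper, backpropagation multiplies many
# derivatives together; factors below 1 (the sigmoid derivative is at most 0.25) shrink the
# gradient toward 0, while factors above 1 make it blow up.
# ReLU's derivative is 1 for positive inputs, which helps against vanishing gradients.
model.add(keras.layers.Dense(300, activation="relu"))
```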
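To make the vanishing-gradient comment concrete: backpropagation multiplies one activation derivative per layer. The sketch below is a deliberately simplified model that ignores the weights and looks only at that product, in the best case for sigmoid (every input at 0, where its derivative reaches its maximum of 0.25), compared with ReLU on positive inputs (derivative 1). The helper names are illustrative, not Keras functions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)  # maximum value 0.25, reached at x = 0

def relu_grad(x):
    return np.where(x > 0, 1.0, 0.0)  # 1 for positive inputs, 0 otherwise

# Product of per-layer activation derivatives (weights ignored, a simplifying assumption)
for depth in (1, 5, 10, 20):
    print(depth, sigmoid_grad(0.0) ** depth, relu_grad(np.array(1.0)) ** depth)
```

Even in sigmoid's best case the factor after 10 layers is below one millionth, while the ReLU factor stays at 1 on the active path.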
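```python
# relu: y = max(0, x); returns x when x > 0, otherwise 0
# softmax: turns a vector into a probability distribution. For x = [x1, x2, x3]:
#   y = [e^x1/sum, e^x2/sum, e^x3/sum], where sum = e^x1 + e^x2 + e^x3
# This 100-unit layer is fully connected to the 300-unit layer before it
model.add(keras.layers.Dense(100, activation="relu"))
# Output layer: 10 units, one probability per class
model.add(keras.layers.Dense(10, activation="softmax"))
```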
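The relu and softmax formulas from the comments, written out in NumPy as a sketch. Subtracting the maximum before exponentiating is a common numerical-stability trick that is not in the original comment; it leaves the result unchanged because the common factor e^(-max) cancels between each numerator and the sum.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)  # y = max(0, x), element-wise

def softmax(x):
    # subtract the max so exp() cannot overflow; the ratios are unchanged
    z = np.exp(x - np.max(x))
    return z / z.sum()  # e^xi / (e^x1 + e^x2 + ...)

print(relu(np.array([-2.0, 0.0, 3.0])))  # [0. 0. 3.]
x = np.array([1.0, 2.0, 3.0])
p = softmax(x)
print(p, p.sum())  # probabilities that sum to 1
```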
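```python
# sparse_categorical_crossentropy: use when each label y is an integer class index (here 0-9);
#   it is equivalent to one-hot encoding y and applying categorical_crossentropy.
#   If the labels are already one-hot vectors, use categorical_crossentropy instead.
# optimizer: the method used to update the weights during training,
#   e.g. optimizer="adam" or "sgd" (gradient-based optimization algorithms)
# metrics: values reported during training and evaluation
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="adam",
              metrics=["accuracy"])
```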
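The two losses differ only in the label format. A minimal NumPy sketch with made-up predicted probabilities for two samples and three classes shows that the sparse version on integer labels gives the same per-sample values as the categorical version on one-hot labels. Keras additionally averages over the batch and clips probabilities away from 0; this sketch omits both.

```python
import numpy as np

def categorical_crossentropy(y_onehot, probs):
    # -sum over classes of y * log(p), per sample
    return -np.sum(y_onehot * np.log(probs), axis=-1)

def sparse_categorical_crossentropy(y_int, probs):
    # pick the predicted probability of the true class directly
    return -np.log(probs[np.arange(len(y_int)), y_int])

# Illustrative softmax outputs (made up for the example)
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
y_int = np.array([0, 2])      # integer labels
y_onehot = np.eye(3)[y_int]   # the same labels, one-hot encoded

print(sparse_categorical_crossentropy(y_int, probs))
print(categorical_crossentropy(y_onehot, probs))
```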
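```python
# Show the model architecture
# Output shapes and parameters:
# Layer 1 (Flatten): [None, 784], i.e. [batch size, 784]
# Layer 2 (Dense, 300 units): [None, 784] · W + b -> [None, 300],
#   with W.shape = [784, 300] and b.shape = [300], i.e. 784 * 300 + 300 = 235,500 parameters
model.summary()
```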
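The same W and b bookkeeping extends to every Dense layer: a layer with `n_in` inputs and `n_out` units has `n_in * n_out + n_out` parameters, and `Flatten` has none. A quick check of the totals `model.summary()` should report for this 784-300-100-10 network (`dense_params` is a hypothetical helper):

```python
def dense_params(n_in, n_out):
    # weight matrix W has n_in * n_out entries, bias b has n_out entries
    return n_in * n_out + n_out

layer_sizes = [784, 300, 100, 10]  # Flatten output, then the three Dense layers
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
print(per_layer, sum(per_layer))  # [235500, 30100, 1010] 266610
```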
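```python
# Train the model; the validation set is evaluated at the end of every epoch
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid))
# history.history is a dict of per-epoch values: loss, accuracy, val_loss, val_accuracy
# Evaluate on the test set: returns [test loss, test accuracy]
print(model.evaluate(x_test_scaled, y_test))
```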
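`model.evaluate` returns the loss followed by the metrics passed to `compile`, here accuracy. With integer labels, accuracy amounts to taking the class with the highest softmax probability and comparing it with the true label. The sketch below does that by hand on made-up probabilities; `accuracy_from_probs` is a hypothetical name, not a Keras API.

```python
import numpy as np

def accuracy_from_probs(probs, y_true):
    # predicted class = index of the largest probability in each row
    y_pred = np.argmax(probs, axis=1)
    return float(np.mean(y_pred == y_true))

# Illustrative softmax outputs for 4 samples and 3 classes (made up)
probs = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.1],
                  [0.2, 0.2, 0.6],
                  [0.5, 0.4, 0.1]])
y_true = np.array([1, 0, 1, 0])
print(accuracy_from_probs(probs, y_true))  # 0.75
```

The same `np.argmax` over a row of `model.predict(...)` output, used as an index into `class_names`, gives the predicted clothing category.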
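```python
# Plot the learning curves
def plot_learning_curves(history):
    # one column per metric, one row per epoch
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    plt.gca().set_ylim(0, 1)  # y-axis from 0 to 1
    plt.show()

plot_learning_curves(history)
```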