TensorFlow 2 Keras classification model
2022-07-02 06:29:00 【图图是只猫】
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
# Dataset
fashion_mnist = keras.datasets.fashion_mnist
# Training and test sets
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
# Hold out the first 5,000 training samples as the validation set
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
# Standardize the inputs: x = (x - mean) / std (standard deviation, not variance).
# The scaler is fitted on the training set only and reused for validation and test.
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(
    x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
    x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
    x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
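The standardization step can be reproduced by hand to see what the scaler computes. The following is a minimal sketch in plain NumPy; the `pixels` array is hypothetical toy data, not part of the original script, and it only illustrates the formula x = (x - mean) / std on a single flattened column.

```python
import numpy as np

# Hypothetical toy data standing in for image pixels (values 0..255).
pixels = np.array([[0, 64], [128, 255]], dtype=np.float32)

# Flatten to one column so a single global mean and std are computed,
# mirroring the reshape(-1, 1) used in the script.
column = pixels.reshape(-1, 1)
mean = column.mean()
std = column.std()  # population standard deviation (ddof=0)

# Apply x = (x - mean) / std, then restore the original shape.
scaled = ((column - mean) / std).reshape(pixels.shape)

print(scaled.mean())  # close to 0
print(scaled.std())   # close to 1
```

Reshaping to a single column means every pixel position shares one mean and one standard deviation, rather than each of the 784 positions getting its own.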
def show_single_image(img_arr):
    plt.imshow(img_arr, cmap="binary")
    plt.show()


def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    assert len(x_data) == len(y_data)
    assert n_rows * n_cols <= len(x_data)
    # Figure width and height, in inches
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
            # Create one subplot per image
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(x_data[index], cmap="binary", interpolation='nearest')
            plt.axis('off')
            plt.title(class_names[y_data[index]])
    plt.show()


class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# show_imgs(3, 5, x_train, y_train, class_names)
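# Build the model: Sequential stacks layers linearly
model = keras.models.Sequential()
# Flatten each 28x28 image into a 784-element vector
model.add(keras.layers.Flatten(input_shape=[28, 28]))
# Dense: a fully connected layer; every unit is connected to every output of the previous layer.
# Vanishing and exploding gradients: backpropagation through a deep network multiplies many
# derivatives together, so the gradient shrinks toward 0 when they are small (the sigmoid
# derivative is at most 0.25) or blows up when they are larger than 1.
# relu: y = max(0, x); returns x when x > 0, otherwise 0
model.add(keras.layers.Dense(300, activation="relu"))
# These 100 units are fully connected to the 300 units of the previous layer
model.add(keras.layers.Dense(100, activation="relu"))
# softmax: turns a vector into a probability distribution. For x = [x1, x2, x3]:
# y = [e^x1/sum, e^x2/sum, e^x3/sum], where sum = e^x1 + e^x2 + e^x3
model.add(keras.layers.Dense(10, activation="softmax"))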
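The two activation functions used by the model can be written out in a few lines of NumPy. This is an illustrative sketch only, not how Keras implements them; the helper names `relu` and `softmax` and the sample `logits` are assumptions made for the example. Subtracting the maximum before exponentiating is a common numerical-stability trick and does not change the result.

```python
import numpy as np

def relu(x):
    # y = max(0, x), applied element-wise
    return np.maximum(0, x)

def softmax(x):
    # Shift by the maximum so np.exp never sees large positive values;
    # the shift cancels out in the ratio.
    exps = np.exp(x - np.max(x))
    return exps / exps.sum()

logits = np.array([1.0, 2.0, 3.0])  # hypothetical outputs of the last layer
probs = softmax(logits)

print(relu(np.array([-1.0, 0.0, 2.0])))
print(probs)        # largest logit gets the largest probability
print(probs.sum())  # sums to 1 up to floating-point rounding
```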
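# sparse_categorical_crossentropy: use when each label y is an integer class index;
# it is equivalent to converting y to a one-hot vector first.
# If the labels are already one-hot vectors, use categorical_crossentropy instead.
# optimizer: the algorithm that updates the weights from the gradients, e.g. "adam" or "sgd"
# metrics: quantities reported during training and evaluation, here accuracy
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])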
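To make the difference between the two loss names concrete, the sketch below computes the per-sample cross-entropy both ways in NumPy and shows that they agree. The `probs` and `labels` arrays are made-up values for illustration; this is not the Keras implementation, only the underlying arithmetic.

```python
import numpy as np

# Hypothetical softmax outputs for two samples over three classes.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
# Integer class indices, the label format sparse_categorical_crossentropy expects.
labels = np.array([0, 1])

# Sparse form: pick the predicted probability of the true class directly.
sparse_loss = -np.log(probs[np.arange(len(labels)), labels])

# Categorical form: convert the labels to one-hot vectors first.
one_hot = np.eye(probs.shape[1])[labels]
categorical_loss = -(one_hot * np.log(probs)).sum(axis=1)

print(sparse_loss)
print(categorical_loss)  # same values as sparse_loss
```

With integer labels the sparse form avoids materializing the one-hot matrix, which is why it suits the raw Fashion-MNIST labels.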
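# Print the model architecture
# Shapes and parameters:
# Layer 1 (Flatten): output [None, 784], where None is the batch size
# Layer 2 (Dense): [None, 784] @ W + b -> [None, 300], with W.shape = [784, 300] and b.shape = [300]
model.summary()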
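The parameter counts reported by `model.summary()` follow from the shapes: each Dense layer has `n_in * n_out` weights plus `n_out` biases. A small sketch of that arithmetic, with `dense_param_count` as a hypothetical helper name:

```python
# Layer widths of the model: flattened input, two hidden layers, output.
layer_sizes = [784, 300, 100, 10]

def dense_param_count(n_in, n_out):
    # Weight matrix of shape [n_in, n_out] plus one bias per output unit.
    return n_in * n_out + n_out

counts = [dense_param_count(n_in, n_out)
          for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
total = sum(counts)

print(counts)  # [235500, 30100, 1010]
print(total)   # 266610
```

The Flatten layer contributes no parameters because it only reshapes its input.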
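# Train the model, checking against the validation set after every epoch
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid))
# history.history holds the per-epoch loss and metric values
# Evaluate on the test set; prints [loss, accuracy]
print(model.evaluate(x_test_scaled, y_test))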
# Plot the learning curves
def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    plt.gca().set_ylim(0, 1)
    plt.show()


plot_learning_curves(history)