当前位置:网站首页>TensoFlow学习记录(二):基础操作
TensoFlow学习记录(二):基础操作
2022-08-04 01:05:00 【狸狸Arina】
1. 数据类型
- 数据类型有整型(默认是int32),浮点型(默认是float32),以及布尔型,字符串;
- Tensor 常见属性:
- .device 当前设备信息;
- .numpy() 将Tensor转换为numpy格式;
- .ndim .shape .rank 返回Tensor形状;
- .is_tensor 判断是否是Tensor 类型
- 数据类型转换
- tf.convert_to_tensor 将数据转换成tensor类型,当从Numpy转换成tensor的时候,会默认是int64,需要指定一下类型,才能成为tf默认的类型也就是int32;
- tf.cast() 实现tensor的数据类型转换;
- 求导特性
- 对参数求梯度的,需要variable包装一下,就拥有了trainable属性,这样就才求梯度。假如是自己写传播过程,更新后的参数也需要用variable包装;
- 对参数求梯度的,需要variable包装一下,就拥有了trainable属性,这样就才求梯度。假如是自己写传播过程,更新后的参数也需要用variable包装;
- To numpy
2. 创建Tensor
- from numpy, list
- tf.zeros
- tf.zeros_like
- tf.ones
- Fill
- Normal 正态分布,传入形状,可指定均值方差
- truncated_normal 裁剪过后的数据,裁去了前后分布太少的数据,只从中间数据多的地方取数据,同样可以指定均值方差。
- Uniform 均匀分布初始化,形状,最小值,最大值
- random.shuffle 随机打散,可以打散一个索引顺序,通过tf.gather去对应,这样可以实现两个同样行数的数据,进行索引一一对应的随机打散
- tf.constant
3. 索引与切片
- Basic indexing
- Numpy-style indexing
- start:end
- Indexing by :
- Indexing by ::
- ::-1
- …
- tf.gather()
- 输入参数:数据、维度、索引
- 输入参数:数据、维度、索引
- tf.gather_nd
- 前面输入数据,后面填取的联合维度。只把最内层的括号当做联合索引的坐标;
- 前面输入数据,后面填取的联合维度。只把最内层的括号当做联合索引的坐标;
- tf.boolean_mask
- 按布尔值索引,不指定维度相当于是第一个维度,指定axis就会根据axis去索引。给索引矩阵也可以;
- 按布尔值索引,不指定维度相当于是第一个维度,指定axis就会根据axis去索引。给索引矩阵也可以;
4. 维度变换
- Reshape
- tf.transpose 转置,perm数字指的是数字所在位置上放哪一个原来的维度
- pytorch中图像存储维度是[b,c,h,w],tf中是[b,h,w,c];
- pytorch中图像存储维度是[b,c,h,w],tf中是[b,h,w,c];
- tf.expand_dims
- tf.squeeze 可以去掉为1的维度。不指定维度的话就去掉所有的为1的维度
5. Broadcasting
- tf.broadcast_to
- Broadcast VS Tile
6. 数学运算
- ±*/%//
- tf.math.log tf.exp
- log2, log10?
- pow, sqrt
- @ matmul
7. 前向传播
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# x: [60k, 28, 28],
# y: [60k]
(x, y), _ = datasets.mnist.load_data()
# x: [0~255] => [0~1.]
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
print(x.shape, y.shape, x.dtype, y.dtype)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))
train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter = iter(train_db)
sample = next(train_iter)
print('batch:', sample[0].shape, sample[1].shape)
# [b, 784] => [b, 256] => [b, 128] => [b, 10]
# [dim_in, dim_out], [dim_out]
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))
lr = 1e-3
for epoch in range(10): # iterate db for 10
for step, (x, y) in enumerate(train_db): # for every batch
# x:[128, 28, 28]
# y: [128]
# [b, 28, 28] => [b, 28*28]
x = tf.reshape(x, [-1, 28*28])
with tf.GradientTape() as tape: # tf.Variable
# x: [b, 28*28]
# h1 = [email protected] + b1
# [b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b, 256] + [b, 256]
h1 = [email protected] + tf.broadcast_to(b1, [x.shape[0], 256])
h1 = tf.nn.relu(h1)
# [b, 256] => [b, 128]
h2 = [email protected] + b2
h2 = tf.nn.relu(h2)
# [b, 128] => [b, 10]
out = [email protected] + b3
# compute loss
# out: [b, 10]
# y: [b] => [b, 10]
y_onehot = tf.one_hot(y, depth=10)
# mse = mean(sum(y-out)^2)
# [b, 10]
loss = tf.square(y_onehot - out)
# mean: scalar
loss = tf.reduce_mean(loss)
# compute gradients
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
# print(grads)
# w1 = w1 - lr * w1_grad
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5])
if step % 100 == 0:
print(epoch, step, 'loss:', float(loss))
边栏推荐
- nodeJs--async/await
- nodejs+npm的安装与配置
- 优秀的测试/开发程序员,是怎样修炼的?步步为营地去执行......
- typescript58 - generic classes
- nodejs+express realizes the access to the database mysql and displays the data on the page
- dynamic memory two
- typescript52 - simplify generic function calls
- MongoDB数据接入实践
- 微服务的简单介绍
- 114. 如何通过单步调试的方式找到引起 Fiori Launchpad 路由错误的原因
猜你喜欢
随机推荐
How to find the cause of Fiori Launchpad routing errors by single-step debugging
C 学生管理系统_添加学生
贴纸拼词 —— 记忆化搜索 / 状压DP
Mvc, Mvp and Mvvm
dynamic memory two
Modulo operation (MOD)
.NET静态代码织入——肉夹馍(Rougamo) 发布1.1.0
快速入门EasyX图形编程
Web3 安全风险令人生畏?应该如何应对?
微服务的简单介绍
Jmeter cross-platform operation CSV files
观察者模式
nodejs+express实现数据库mysql的访问,并展示数据到页面上
114. How to find the cause of Fiori Launchpad routing error by single-step debugging
LeetCode third topic (the Longest Substring Without Repeating Characters) trilogy # 3: two optimization
typescript50 - type specification between cross types and interfaces
Google Earth Engine ——利用公开的河流数据计算河流的有效宽度
typescript48-函数之间的类型兼容性
七夕佳节即将来到,VR全景云游为你神助攻
Android interview questions and answer analysis of major factories in the first half of 2022 (continuously updated...)