当前位置:网站首页>TensoFlow学习记录(二):基础操作
TensoFlow学习记录(二):基础操作
2022-08-04 01:05:00 【狸狸Arina】
1. 数据类型
- 数据类型有整型(默认是int32),浮点型(默认是float32),以及布尔型,字符串;

- Tensor 常见属性:
- .device 当前设备信息;
- .numpy() 将Tensor转换为numpy格式;
- .ndim .shape .rank 返回Tensor形状;
- .is_tensor 判断是否是Tensor 类型


- 数据类型转换
- tf.convert_to_tensor 将数据转换成tensor类型,当从Numpy转换成tensor的时候,会默认是int64,需要指定一下类型,才能成为tf默认的类型也就是int32;
- tf.cast() 实现tensor的数据类型转换;


- 求导特性
- 对参数求梯度的,需要variable包装一下,就拥有了trainable属性,这样就才求梯度。假如是自己写传播过程,更新后的参数也需要用variable包装;

- 对参数求梯度的,需要variable包装一下,就拥有了trainable属性,这样就才求梯度。假如是自己写传播过程,更新后的参数也需要用variable包装;
- To numpy

2. 创建Tensor
- from numpy, list

- tf.zeros

- tf.zeros_like

- tf.ones

- Fill

- Normal 正态分布,传入形状,可指定均值方差
- truncated_normal 裁剪过后的数据,裁去了前后分布太少的数据,只从中间数据多的地方取数据,同样可以指定均值方差。

- Uniform 均匀分布初始化,形状,最小值,最大值

- random.shuffle 随机打散,可以打散一个索引顺序,通过tf.gather去对应,这样可以实现两个同样行数的数据,进行索引一一对应的随机打散

- tf.constant

3. 索引与切片
- Basic indexing

- Numpy-style indexing

- start:end

- Indexing by :

- Indexing by ::

- ::-1

- …

- tf.gather()
- 输入参数:数据、维度、索引

- 输入参数:数据、维度、索引
- tf.gather_nd
- 前面输入数据,后面填取的联合维度。只把最内层的括号当做联合索引的坐标;


- 前面输入数据,后面填取的联合维度。只把最内层的括号当做联合索引的坐标;
- tf.boolean_mask
- 按布尔值索引,不指定维度相当于是第一个维度,指定axis就会根据axis去索引。给索引矩阵也可以;

- 按布尔值索引,不指定维度相当于是第一个维度,指定axis就会根据axis去索引。给索引矩阵也可以;
4. 维度变换
- Reshape


- tf.transpose 转置,perm数字指的是数字所在位置上放哪一个原来的维度
- pytorch中图像存储维度是[b,c,h,w],tf中是[b,h,w,c];

- pytorch中图像存储维度是[b,c,h,w],tf中是[b,h,w,c];
- tf.expand_dims

- tf.squeeze 可以去掉为1的维度。不指定维度的话就去掉所有的为1的维度

5. Broadcasting

- tf.broadcast_to

- Broadcast VS Tile

6. 数学运算
- ±*/%//

- tf.math.log tf.exp

- log2, log10?

- pow, sqrt

- @ matmul


7. 前向传播
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# x: [60k, 28, 28],
# y: [60k]
(x, y), _ = datasets.mnist.load_data()
# x: [0~255] => [0~1.]
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
print(x.shape, y.shape, x.dtype, y.dtype)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))
train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter = iter(train_db)
sample = next(train_iter)
print('batch:', sample[0].shape, sample[1].shape)
# [b, 784] => [b, 256] => [b, 128] => [b, 10]
# [dim_in, dim_out], [dim_out]
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))
lr = 1e-3
for epoch in range(10): # iterate db for 10
for step, (x, y) in enumerate(train_db): # for every batch
# x:[128, 28, 28]
# y: [128]
# [b, 28, 28] => [b, 28*28]
x = tf.reshape(x, [-1, 28*28])
with tf.GradientTape() as tape: # tf.Variable
# x: [b, 28*28]
# h1 = [email protected] + b1
# [b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b, 256] + [b, 256]
h1 = [email protected] + tf.broadcast_to(b1, [x.shape[0], 256])
h1 = tf.nn.relu(h1)
# [b, 256] => [b, 128]
h2 = [email protected] + b2
h2 = tf.nn.relu(h2)
# [b, 128] => [b, 10]
out = [email protected] + b3
# compute loss
# out: [b, 10]
# y: [b] => [b, 10]
y_onehot = tf.one_hot(y, depth=10)
# mse = mean(sum(y-out)^2)
# [b, 10]
loss = tf.square(y_onehot - out)
# mean: scalar
loss = tf.reduce_mean(loss)
# compute gradients
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
# print(grads)
# w1 = w1 - lr * w1_grad
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5])
if step % 100 == 0:
print(epoch, step, 'loss:', float(loss))
边栏推荐
- C 学生管理系统 显示链表信息、删除链表
- 【虚拟化生态平台】虚拟化平台esxi挂载USB硬盘
- Use nodejs switch version (no need to uninstall and reinstall)
- Observability:你所需要知道的关于 Syslog 的一些知识
- 多渠道打包
- nodejs 安装多版本 版本切换
- 中原银行实时风控体系建设实践
- 螺旋矩阵_数组 | leecode刷题笔记
- 身为程序员的我们如何卷死别人?破局重生。
- Analysis of usage scenarios of mutex, read-write lock, spin lock, and atomic operation instructions xaddl and cmpxchg
猜你喜欢

Vant3 - click on the corresponding name name to jump to the next page corresponding to the location of the name of the TAB bar

nodeJs--async/await

【超详细教程】LVS+KeepAlived高可用部署实战应用

取模运算(MOD)

600MHz频段来了,它会是新的黄金频段吗?

虚拟机CentOS7中无图形界面安装Oracle

【虚拟户生态平台】虚拟化平台安装时遇到的坑

MATLAB三维绘图命令plot3入门

字符串变形

手撕Gateway源码,今日撕工作流程、负载均衡源码
随机推荐
How to find the cause of Fiori Launchpad routing errors by single-step debugging
jmeter分布式压测
字符串变形
114. How to find the cause of Fiori Launchpad routing error by single-step debugging
C 学生管理系统_添加学生
Tanabata festival coming, VR panoramic look god assists for you
typescript57 - Array generic interface
Mvc、Mvp和Mvvm
Installation and configuration of nodejs+npm
dynamic memory two
【详细教程】一文参透MongoDB聚合查询
typescript50-交叉类型和接口之间的类型说明
Getting started with MATLAB 3D drawing command plot3
ASP.NET 获取数据库的数据并写入到excel表格中
KunlunBase 1.0 is released!
Shell编程之循环语句(for、while)
手撕Gateway源码,今日撕工作流程、负载均衡源码
typescript48-函数之间的类型兼容性
电子制造企业部署WMS仓储管理系统的好处是什么
【虚拟化生态平台】虚拟化平台搭建