Pretty Tensor - Fluent Neural Networks in TensorFlow

Overview

Pretty Tensor provides a high-level builder API for TensorFlow. It provides thin wrappers around Tensors so that you can easily build multi-layer neural networks.

Pretty Tensor provides a set of objects that behave like Tensors but also support a chainable object syntax for quickly defining neural networks and other layered architectures in TensorFlow.

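import prettytensor as pt
import tensorflow as tf

# input_data and labels are existing tensors (for example, placeholders).
result = (pt.wrap(input_data)
          .flatten()
          .fully_connected(200, activation_fn=tf.nn.relu)
          .fully_connected(10, activation_fn=None)
          .softmax(labels, name='softmax'))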
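The chain works because every method returns a new wrapped object, so calls compose left to right. The toy, pure-Python sketch below illustrates only that pattern. The `Chain` class and its methods are hypothetical stand-ins, not Pretty Tensor's implementation, and they operate on plain lists rather than Tensors.

```python
class Chain:
    """Hypothetical toy wrapper: each method returns a new Chain, so calls compose."""

    def __init__(self, value):
        self.value = value

    def flatten(self):
        # Flatten a list of rows into one flat list.
        return Chain([x for row in self.value for x in row])

    def scale(self, factor):
        # Stand-in for a layer: multiply every element by a constant.
        return Chain([factor * x for x in self.value])

    def relu(self):
        # Element-wise max(x, 0), like tf.nn.relu.
        return Chain([max(x, 0.0) for x in self.value])


result = (Chain([[1.0, -2.0], [3.0, -4.0]])
          .flatten()
          .scale(2.0)
          .relu())
print(result.value)  # [2.0, 0.0, 6.0, 0.0]
```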

For full documentation of all operations available on the PrettyTensor object, see Available Operations, or check out the complete documentation.

See the tutorial directory for samples: tutorial/

Installation

The easiest way to install is with pip:

  1. Follow the instructions at tensorflow.org
  2. pip install prettytensor

Note: Head is tested against the TensorFlow nightly builds, and the pip package is tested against the TensorFlow release.

Quick start

Imports

import prettytensor as pt
import tensorflow as tf

Set up your input

# BATCHES, BATCH_SIZE, DATA_SIZE and CLASSES are integers you choose.
my_inputs = ...  # numpy array of shape (BATCHES, BATCH_SIZE, DATA_SIZE)
my_labels = ...  # numpy array of shape (BATCHES, BATCH_SIZE, CLASSES)
input_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, DATA_SIZE))
label_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, CLASSES))
pretty_input = pt.wrap(input_tensor)

Define your model

softmax, loss = (pretty_input.
                 fully_connected(100).
                 softmax_classifier(CLASSES, labels=label_tensor))

Train and evaluate

accuracy = softmax.evaluate_classifier(label_tensor)

optimizer = tf.train.GradientDescentOptimizer(0.1)  # learning rate
train_op = pt.apply_optimizer(optimizer, losses=[loss])

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    for inp, label in zip(my_inputs, my_labels):
        # Run one training step and fetch this batch's loss and accuracy.
        _, unused_loss_value, accuracy_value = sess.run(
            [train_op, loss, accuracy],
            {input_tensor: inp, label_tensor: label})
        print('Accuracy: %g' % accuracy_value)
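For intuition about the two values fetched in the loop, here is a plain-Python sketch of softmax cross-entropy and arg-max accuracy on a tiny batch. It assumes the loss is the mean cross-entropy over the batch and that accuracy is the fraction of examples whose highest-scoring class matches the one-hot label. The helper names are hypothetical, and Pretty Tensor's exact reductions and weighting options may differ.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, one_hot):
    # -sum(y * log(p)) for a single example; zero-weight terms are skipped.
    return -sum(y * math.log(p) for p, y in zip(probs, one_hot) if y > 0)


def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])


def batch_metrics(batch_logits, batch_labels):
    # Hypothetical helper: mean loss and accuracy over one batch.
    probs = [softmax(z) for z in batch_logits]
    loss = sum(cross_entropy(p, y) for p, y in zip(probs, batch_labels)) / len(probs)
    correct = sum(argmax(p) == argmax(y) for p, y in zip(probs, batch_labels))
    return loss, correct / len(probs)


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 3.0]]
labels = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
loss, accuracy = batch_metrics(logits, labels)
print('Accuracy: %g' % accuracy)  # Accuracy: 0.5
```

The first example is classified correctly and the second is not, so accuracy is 0.5, while the loss is dominated by the low probability assigned to the second example's true class.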

Features

Thin

Full power of TensorFlow is easy to use

Pretty Tensors can be used (almost) everywhere that a tensor can. Just call pt.wrap to make a tensor pretty.

You can also add any existing TensorFlow function to the chain using apply. apply passes the current Tensor as the first argument and forwards all other arguments unchanged.

Note: because apply is so generic, Pretty Tensor doesn't try to wrap the world.

Plays well with other libraries

Pretty Tensor uses standard TensorFlow idioms, so it plays well with other libraries: you can use it for a small part of a model or throughout. Just make sure to run the update_ops on each training step (see with_update_ops).

Terse

You've already seen how a Pretty Tensor is chainable, and you may have noticed that it takes care of handling the input shape. One other feature worth noting is defaults: with them you can specify reused values in a single place instead of repeating them.

with pt.defaults_scope(activation_fn=tf.nn.relu):
  hidden_output2 = (pretty_images.flatten()
                   .fully_connected(100)
                   .fully_connected(100))

Check out the documentation to see all supported defaults.

Code matches model

Sequential mode lets you break model construction across lines and provides the subdivide syntactic sugar that makes it easy to define and understand complex structures like an inception module:

with pt.defaults_scope(activation_fn=tf.nn.relu):
  seq = pretty_input.sequential()
  with seq.subdivide(4) as towers:
    towers[0].conv2d(1, 64)
    towers[1].conv2d(1, 112).conv2d(3, 224)
    towers[2].conv2d(1, 32).conv2d(5, 64)
    towers[3].max_pool(2, 3).conv2d(1, 32)

Inception module showing branch and rejoin
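Conceptually, the towers all read the same input and are rejoined when the `with` block exits. The sketch below assumes the rejoin concatenates tower outputs along the last (depth) axis, as in an inception module. It uses nested lists in place of Tensors, and `subdivide_and_concat` is a hypothetical helper.

```python
def subdivide_and_concat(features, towers):
    """Hypothetical sketch: run every tower on the same input, then
    concatenate their outputs along the last (depth) axis."""
    outputs = [tower(features) for tower in towers]
    # Each output is a list of per-position vectors; join them position by position.
    return [sum((out[i] for out in outputs), []) for i in range(len(features))]


# Toy "towers": each maps a per-position vector to a few output channels.
towers = [
    lambda f: [[v[0]] for v in f],               # 1 channel
    lambda f: [[v[0] + v[1], v[1]] for v in f],  # 2 channels
    lambda f: [[max(v)] for v in f],             # 1 channel
]
features = [[1.0, 2.0], [3.0, -1.0]]  # 2 positions, depth 2
print(subdivide_and_concat(features, towers))
# [[1.0, 3.0, 2.0, 2.0], [3.0, 2.0, -1.0, 3.0]]
```

The output depth is the sum of the tower depths (1 + 2 + 1 = 4), while the number of positions is unchanged.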

Templates provide guaranteed parameter reuse and make unrolling recurrent networks easy:

# BATCH is the batch size; pretty_input_array is a list of wrapped inputs,
# one per time step.
output = []
s = tf.zeros([BATCH, 256 * 2])

A = (pt.template('x')
     .lstm_cell(num_units=256, state=pt.UnboundVariable('state')))

for x in pretty_input_array:
  h, s = A.construct(x=x, state=s)
  output.append(h)
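The reuse guarantee can be illustrated with a toy template that creates its parameters on the first `construct` call and shares them on every later call. Everything here (`Template`, `rnn_cell`, the scalar weights) is hypothetical and pure Python. It mirrors the unrolling pattern above, not Pretty Tensor's variable-scope machinery.

```python
import math


class Template:
    """Hypothetical toy template: parameters are created on the first
    construct() call and reused by every later call."""

    def __init__(self, cell):
        self.cell = cell
        self.params = None
        self.created = 0  # counts how many times parameters were created

    def construct(self, **bindings):
        if self.params is None:
            self.params = {'w': 0.5, 'u': 0.25}  # stand-in for trainable variables
            self.created += 1
        return self.cell(self.params, **bindings)


def rnn_cell(params, x, state):
    # Toy recurrent step: the new state depends on the input and the previous state.
    h = math.tanh(params['w'] * x + params['u'] * state)
    return h, h


A = Template(rnn_cell)
outputs, s = [], 0.0
for x in [1.0, 2.0, 3.0]:
    h, s = A.construct(x=x, state=s)
    outputs.append(h)
print(len(outputs), A.created)  # 3 1
```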

There are also some convenient shorthands for LSTMs and GRUs:

pretty_input_array.sequence_lstm(num_units=256)

Unrolled RNN

Extensible

You can call any existing operation by using apply, and it will simply substitute the current tensor for the first argument.

pretty_input.apply(tf.multiply, 5)
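The argument order of apply can be shown with a toy wrapper: the wrapped value is passed first, and any remaining arguments are forwarded unchanged. `Wrapped` is a hypothetical stand-in, and `operator.mul` plays the role of `tf.multiply`.

```python
import operator


class Wrapped:
    """Hypothetical toy wrapper illustrating apply's argument order."""

    def __init__(self, value):
        self.value = value

    def apply(self, fn, *args, **kwargs):
        # The wrapped value becomes the first positional argument;
        # everything else is forwarded unchanged.
        return Wrapped(fn(self.value, *args, **kwargs))


x = Wrapped(3)
print(x.apply(operator.mul, 5).value)                # 15
print(x.apply(pow, 2).apply(operator.sub, 1).value)  # 8
```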

You can also create a new operation. There are two supported registration mechanisms for adding your own functions. @Register() lets you create a method on PrettyTensor that operates on Tensors and returns either a loss or a new value. Name scoping and variable scoping are handled by the framework.

The following method adds the leaky_relu method to every Pretty Tensor:

@pt.Register
def leaky_relu(input_pt):
  return tf.where(tf.greater(input_pt, 0.0), input_pt, 0.01 * input_pt)
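A registration decorator of this kind can be sketched in plain Python. It attaches the function to the wrapper class as a method and re-wraps the result so that chaining keeps working. `Pretty` and `register` are hypothetical, and this toy skips the name and variable scoping that the framework handles.

```python
class Pretty:
    """Hypothetical toy stand-in for a wrapped tensor (a list of floats)."""

    def __init__(self, values):
        self.values = list(values)


def register(fn):
    # Attach fn as a method; its return value is re-wrapped so chaining keeps working.
    def method(self, *args, **kwargs):
        return Pretty(fn(self.values, *args, **kwargs))
    method.__name__ = fn.__name__
    setattr(Pretty, fn.__name__, method)
    return fn


@register
def leaky_relu(values, alpha=0.01):
    # Same rule as the tf.where version: keep positives, scale the rest by alpha.
    return [v if v > 0.0 else alpha * v for v in values]


print(Pretty([2.0, -100.0]).leaky_relu().values)  # [2.0, -1.0]
```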

@RegisterCompoundOp() is like adding a macro: it is designed to group together common sets of operations.

Safe variable reuse

Within a graph, you can reuse variables by using templates. A template is just like a regular graph except that some variables are left unbound.

See the PrettyTensor class for more details.

Accessing Variables

Pretty Tensor uses the standard graph collections from TensorFlow to store variables. These can be accessed using tf.get_collection(key) with the following keys:

  • tf.GraphKeys.VARIABLES: all variables that should be saved (including some statistics).
  • tf.GraphKeys.TRAINABLE_VARIABLES: all variables that can be trained (including those before a stop_gradients call). These are what would typically be called the parameters of the model in ML parlance.
  • pt.GraphKeys.TEST_VARIABLES: variables used to evaluate a model. These are typically not saved and are reset by the LocalRunner.evaluate method to get a fresh evaluation.
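As a rough model of how these keys partition variables, the toy registry below stores names under one or more collection keys and returns them by key. The `Graph` class, the key strings, and the variable names are hypothetical. The sketch assumes, as in TensorFlow, that a trainable variable is registered under both the saved-variables key and the trainable key.

```python
from collections import defaultdict

VARIABLES = 'variables'
TRAINABLE_VARIABLES = 'trainable_variables'
TEST_VARIABLES = 'test_variables'


class Graph:
    """Hypothetical toy graph holding named collections of variables."""

    def __init__(self):
        self._collections = defaultdict(list)

    def add_to_collections(self, keys, value):
        for key in keys:
            self._collections[key].append(value)

    def get_collection(self, key):
        # Return a copy so callers cannot mutate the graph's lists.
        return list(self._collections[key])


g = Graph()
g.add_to_collections([VARIABLES, TRAINABLE_VARIABLES], 'fc/weights')
g.add_to_collections([VARIABLES], 'bn/moving_mean')           # saved statistic, not trained
g.add_to_collections([TEST_VARIABLES], 'eval/correct_count')  # evaluation state, not saved
print(g.get_collection(TRAINABLE_VARIABLES))  # ['fc/weights']
```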

Authors

Eider Moore (eiderman)

with key contributions from:

  • Hubert Eichner
  • Oliver Lange
  • Sagar Jain (sagarjn)
Owner: Google