Learning Chinese Character style with conditional GAN

Overview

zi2zi: Master Chinese Calligraphy with Conditional Adversarial Networks

animation

Introduction

zi2zi (字到字, meaning "from character to character") learns East Asian typefaces with a GAN. It is an application and extension of the recently popular pix2pix model to Chinese characters.

Details can be found in this blog post.

Network Structure

Original Model

alt network

The network structure is based on pix2pix, with the addition of a category embedding and two extra losses: category loss from AC-GAN and constant loss from DTN.
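As a rough illustration of how these terms could be combined on the generator side, here is a minimal NumPy sketch. The function name, argument names, and the plain weighted sum are assumptions for illustration; the default weights are simply the L1_penalty and Lconst_penalty values from the train command in this README, and the adversarial and category terms are passed in as precomputed scalars.

```python
import numpy as np

def generator_loss(fake_img, real_img, enc_source, enc_fake,
                   adv_loss, category_loss,
                   l1_penalty=100.0, lconst_penalty=15.0):
    """Hypothetical weighted sum of the generator loss terms."""
    # pix2pix-style L1 loss: pixel distance between generated and target image
    l1 = np.mean(np.abs(fake_img - real_img))
    # DTN-style constant loss: the encoding of the source character and the
    # encoding of the generated character should stay close
    const = np.mean((enc_source - enc_fake) ** 2)
    # adv_loss and category_loss (AC-GAN style) are assumed to be scalars
    return adv_loss + category_loss + l1_penalty * l1 + lconst_penalty * const
```

The actual weighting and any additional terms are defined in the project's model code, which this sketch does not reproduce.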

Updated Model with Label Shuffling

alt network

After sufficient training, d_loss drops to near zero and the model's performance plateaus. Label shuffling mitigates this problem by presenting new challenges to the model.

Specifically, within a given minibatch, for the same set of source characters, we generate two sets of target characters: one with the correct embedding labels, the other with shuffled labels. The shuffled set will likely not have corresponding target images for computing L1_Loss, but it is a good source for all the other losses, forcing the model to generalize beyond the limited set of provided examples. Empirically, label shuffling improves the model's generalization on unseen data, produces better details, and decreases the number of characters required.

You can enable label shuffling by setting the flip_labels=1 option in the train.py script. It is recommended that you enable this after d_loss flatlines around zero, for further tuning.
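The idea of building the two label sets can be sketched as follows. This is a minimal illustration, not the project's training code: make_label_sets and has_target are hypothetical names, and shuffling by permuting the labels inside the minibatch is an assumption.

```python
import numpy as np

def make_label_sets(labels, seed=None):
    """Return (correct, shuffled, has_target) for one minibatch of font labels."""
    rng = np.random.RandomState(seed)
    correct = np.asarray(labels)
    # Permute the labels within the batch, so the source characters are
    # paired with font ids that mostly differ from their true targets
    shuffled = rng.permutation(correct)
    # Where the shuffled label happens to equal the true one, a target image
    # still exists; elsewhere an L1 loss cannot be computed
    has_target = shuffled == correct
    return correct, shuffled, has_target
```

In a training step, the correct set would feed every loss, while the shuffled set would feed only the losses that do not need a ground-truth image.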

Gallery

Compare with Ground Truth

compare

Brush Writing Fonts

brush

Cursive Script (Requested by SNS audience)

cursive

Mingchao Style (宋体/明朝体)

gaussian

Korean

korean

Interpolation

animation

Animation

animation animation

easter egg

How to Use

Step Zero

Download tons of fonts as you please

Requirement

  • Python 2.7
  • CUDA
  • cuDNN
  • TensorFlow >= 1.0.1
  • Pillow(PIL)
  • numpy >= 1.12.1
  • scipy >= 0.18.1
  • imageio

Preprocess

To avoid an I/O bottleneck, preprocessing is necessary to pickle your data into binary format so that it can persist in memory during training.

First, run the command below to generate the font images:

python font2img.py --src_font=src.ttf \
                   --dst_font=tgt.otf \
                   --charset=CN \
                   --sample_count=1000 \
                   --sample_dir=dir \
                   --label=0 \
                   --filter=1 \
                   --shuffle=1

Four default charsets are offered: CN, CN_T (traditional), JP, and KR. You can also point charset to a one-line file, in which case the script generates images of the characters in that file. The filter option is highly recommended: it pre-samples some characters and filters out all images that share the same hash, which usually indicates that the character is missing from the font. The label option is the index in the category embeddings that this font is associated with; it defaults to 0.
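The hash-based filtering described above can be sketched like this. It is a simplified illustration under stated assumptions: the function name, the use of MD5, the dict of raw image bytes, and the threshold parameter are all hypothetical and not taken from font2img.py.

```python
import hashlib
from collections import Counter

def filter_missing_glyphs(images, threshold=2):
    """Drop characters whose rendered image is shared by several characters.

    images: dict mapping a character to the raw bytes of its rendered image.
    A font usually renders every missing glyph as the same blank or "tofu"
    image, so an image hash that occurs `threshold` or more times is treated
    as a missing character.
    """
    hashes = {ch: hashlib.md5(data).hexdigest() for ch, data in images.items()}
    counts = Counter(hashes.values())
    return {ch: data for ch, data in images.items()
            if counts[hashes[ch]] < threshold}
```

For example, if two different characters render to identical bytes, both are dropped and the remaining characters are kept.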

After obtaining all images, run package.py to pickle the images and their corresponding labels into binary format:

python package.py --dir=image_directories \
                  --save_dir=binary_save_directory \
                  --split_ratio=[0,1]

After running this, you will find two objects train.obj and val.obj under the save_dir for training and validation, respectively.
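A minimal sketch of this kind of split-and-pickle step is shown below. The function names are hypothetical, and it assumes that split_ratio is the fraction of examples sent to the validation file and that each example is a (label, image_bytes) tuple; check package.py for the actual behavior.

```python
import pickle
import random

def split_and_pickle(examples, train_path, val_path, split_ratio=0.1, seed=0):
    """Append each example to the train or validation file at random."""
    rng = random.Random(seed)
    with open(train_path, "wb") as ft, open(val_path, "wb") as fv:
        for example in examples:
            # Assumed meaning: split_ratio is the share that goes to validation
            target = fv if rng.random() < split_ratio else ft
            pickle.dump(example, target)

def load_examples(path):
    """Read back every pickled example stored sequentially in one file."""
    examples = []
    with open(path, "rb") as f:
        while True:
            try:
                examples.append(pickle.load(f))
            except EOFError:
                break
    return examples
```

Storing the examples as a sequence of pickles in a single file lets the training code load everything with one sequential read instead of opening thousands of image files.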

Experiment Layout

experiment/
└── data
    ├── train.obj
    └── val.obj

Create an experiment directory under the root of the project, and a data directory within it to hold the two binaries. Assuming a fixed directory layout enforces better data isolation, especially if you have multiple experiments running.

Train

To start training, run the following command:

python train.py --experiment_dir=experiment \
                --experiment_id=0 \
                --batch_size=16 \
                --lr=0.001 \
                --epoch=40 \
                --sample_steps=50 \
                --schedule=20 \
                --L1_penalty=100 \
                --Lconst_penalty=15

Here, schedule is the number of epochs between learning rate decays; at each decay the learning rate is halved. The train command creates the sample, logs, and checkpoint directories under experiment_dir if they do not already exist, where you can check and manage the progress of your training.
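The halving schedule can be illustrated with a small helper. The function name is hypothetical and the sketch covers only the halving rule described above; any lower bound or other adjustment that train.py may apply is not modeled.

```python
def learning_rate_at(epoch, base_lr=0.001, schedule=20):
    """Learning rate in effect during a given zero-based epoch."""
    # Halve once for every `schedule` completed epochs
    return base_lr * 0.5 ** (epoch // schedule)
```

With the values from the train command (lr=0.001, schedule=20, epoch=40), the first 20 epochs run at 0.001 and the last 20 at 0.0005.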

Infer and Interpolate

After training is done, run the command below to run inference on test data:

python infer.py --model_dir=checkpoint_dir/ \
                --batch_size=16 \
                --source_obj=binary_obj_path \
                --embedding_ids=label[s] of the font, separated by commas \
                --save_dir=save_dir/

You can also do interpolation with this command:

python infer.py --model_dir=checkpoint_dir/ \
                --batch_size=10 \
                --source_obj=obj_path \
                --embedding_ids=label[s] of the font, separated by commas \
                --save_dir=frames/ \
                --output_gif=gif_path \
                --interpolate=1 \
                --steps=10 \
                --uroboros=1

It will run through all the pairs of fonts specified in embedding_ids and interpolate between each pair in the specified number of steps.
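Interpolating between font embeddings can be sketched as below. This is an illustration under assumptions, not the code in infer.py: the function name is hypothetical, the pairs are taken to be consecutive entries of embedding_ids, the blend is linear, and the loop flag stands in for what uroboros is assumed to do (append the first font at the end so the animation returns to its start).

```python
import numpy as np

def interpolate_embeddings(embeddings, ids, steps, loop=False):
    """Linearly blend category embeddings between consecutive font ids.

    embeddings: array of shape (num_fonts, dim)
    ids: font labels to walk through, in order
    steps: number of intervals between each pair of fonts
    """
    chain = list(ids)
    if loop:
        # Assumed uroboros behavior: close the chain on the first font
        chain.append(chain[0])
    frames = []
    for start, end in zip(chain[:-1], chain[1:]):
        for alpha in np.linspace(0.0, 1.0, steps + 1):
            frames.append((1.0 - alpha) * embeddings[start]
                          + alpha * embeddings[end])
    return np.stack(frames)
```

Each returned row is one intermediate embedding; feeding the rows to the generator one by one would produce the frames of the animation.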

Pretrained Model

A pretrained model, trained with 27 fonts, can be downloaded here. Only the generator is saved, to reduce the model size. You can use the encoder in this pretrained model to accelerate the training process.

Acknowledgements

Code derived and rehashed from:

License

Apache 2.0

Owner
Yuchen Tian
Born in the year of Snake, now stuck with Python.