Making decision trees competitive with neural networks on CIFAR10, CIFAR100, TinyImagenet200, and ImageNet

Overview

Neural-Backed Decision Trees · Site · Paper · Blog · Video

Try In Colab

Alvin Wan, *Lisa Dunlap, *Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez
*denotes equal contribution

NBDTs match or outperform modern neural networks on CIFAR10, CIFAR100, TinyImagenet200, and ImageNet, and generalize to unseen classes up to 16% better. Furthermore, our loss improves the original model’s accuracy by up to 2%. We attain 76.60% top-1 accuracy on ImageNet. See the 3-minute YouTube summary.

Updates

  • 2/2/21 Talks: released a 3-minute YouTube video summarizing NBDT, along with a 15-minute technical talk
  • 1/28/21 arXiv: updated arXiv with latest results, improving neural network accuracy, generalization, and interpretability (4 new human studies, 600 responses each).
  • 1/22/21 Accepted: NBDT was accepted to ICLR 2021. Repository has been updated with new results and supporting code.

Quickstart

Running Pretrained NBDT on Examples

Don't want to download? Try your own images on the web demo.

pip install nbdt, and run our CLI on any image. Below, we run a CIFAR10 model on images from the web, which outputs both the class prediction and all the intermediate decisions. Although the bear and zebra classes were not seen at train time, the model still correctly picks animal over vehicle for both.

# install our cli
pip install nbdt

# Cat picture - can be a local image path or an image URL
nbdt https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32
# Prediction: cat // Decisions: animal (Confidence: 99.47%), chordate (Confidence: 99.20%), carnivore (Confidence: 99.42%), cat (Confidence: 99.86%)

# Zebra picture (not in CIFAR10) - picks the closest CIFAR10 animal, which is horse
nbdt https://images.pexels.com/photos/750539/pexels-photo-750539.jpeg?auto=compress&cs=tinysrgb&dpr=2&h=32
# Prediction: horse // Decisions: animal (Confidence: 99.31%), ungulate (Confidence: 99.25%), horse (Confidence: 99.62%)

# Bear picture (not in CIFAR10)
nbdt https://images.pexels.com/photos/158109/kodiak-brown-bear-adult-portrait-wildlife-158109.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32
# Prediction: dog // Decisions: animal (Confidence: 99.51%), chordate (Confidence: 99.35%), carnivore (Confidence: 99.69%), dog (Confidence: 99.22%)

Pictures are taken from pexels.com and are free to use per the Pexels license.

Loading Pretrained NBDTs in Code

Don't want to download? Try inference on a pre-filled Google Colab Notebook.

pip install nbdt to use our models. We have pretrained models for ResNet18 and WideResNet28x10 for CIFAR10, CIFAR100, and TinyImagenet200. See Models for adding other models. See nbdt-pytorch-image-models for EfficientNet on ImageNet.

Try the script below on Google Colab.

from nbdt.model import SoftNBDT
from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10  # use wrn28_10 for TinyImagenet200

model = wrn28_10_cifar10()
model = SoftNBDT(
  pretrained=True,
  dataset='CIFAR10',
  arch='wrn28_10_cifar10',
  model=model)

Example in ~30 lines: See nbdt/bin/nbdt, which loads the pretrained model, loads an image, and runs inference on the image, all in about 30 lines. This file is the nbdt executable used in the previous section. Try this in a Google Colab Notebook.
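For intuition, here is a rough, hypothetical sketch of what such a script could look like. The preprocessing (resize and crop sizes, plus the normalization constants, which are standard CIFAR10 statistics) is an assumption of this sketch, not necessarily what nbdt/bin/nbdt actually uses.

import io
import urllib.request

import torch
import torchvision.transforms as T
from PIL import Image

from nbdt.model import SoftNBDT
from nbdt.models import wrn28_10_cifar10

# load the pretrained NBDT (CIFAR10, WideResNet28x10 backbone)
model = SoftNBDT(pretrained=True, dataset='CIFAR10', arch='wrn28_10_cifar10',
                 model=wrn28_10_cifar10())
model.eval()

# fetch and preprocess an image
url = 'https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32'
image = Image.open(io.BytesIO(urllib.request.urlopen(url).read())).convert('RGB')
transform = T.Compose([
    T.Resize(32), T.CenterCrop(32), T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
x = transform(image)[None]  # add a batch dimension

# the wrapped model is called like the original network and returns class logits
with torch.no_grad():
    logits = model(x)
print('Predicted class index:', logits.argmax(dim=1).item())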

Convert Neural Networks to Decision Trees

To convert your neural network into a neural-backed decision tree, perform the following 3 steps:

  1. First, if you haven't already, pip install the nbdt utility:
pip install nbdt
  2. Second, train the original neural network with an NBDT loss. All NBDT losses work by wrapping the original criterion. To demonstrate this, we wrap the original loss criterion with a soft tree supervision loss below (a minimal training-step sketch follows this list).
from nbdt.loss import SoftTreeSupLoss
criterion = SoftTreeSupLoss(dataset='CIFAR10', criterion=criterion)  # `criterion` is your original loss function e.g., nn.CrossEntropyLoss
  3. Third, perform inference or validate using an NBDT model. All NBDT models work by wrapping the original model you trained in step 2. To demonstrate this, we wrap the model with the SoftNBDT wrapper below. Note this model wrapper is only for inference and validation, not for training.
from nbdt.model import SoftNBDT
model = SoftNBDT(dataset='CIFAR10', model=model)  # `model` is your original model
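Putting step 2 in context, here is a minimal, hypothetical training-step sketch. The model, batch, and optimizer below are placeholders; the only NBDT-specific change is that the wrapped criterion is called exactly like the original criterion.

import torch
import torch.nn as nn
from nbdt.loss import SoftTreeSupLoss
from nbdt.models import ResNet18  # assumed here to build a 10-class CIFAR-style network by default

model = ResNet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = SoftTreeSupLoss(dataset='CIFAR10', criterion=nn.CrossEntropyLoss())

inputs = torch.randn(8, 3, 32, 32)     # stand-in for a CIFAR10 batch
targets = torch.randint(0, 10, (8,))   # stand-in for the labels

optimizer.zero_grad()
outputs = model(inputs)                # original network forward pass
loss = criterion(outputs, targets)     # wrapped loss, same call signature as before
loss.backward()
optimizer.step()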

Example integration with repository: See nbdt-pytorch-image-models, which applies this 3-step integration to a popular image classification repository pytorch-image-models.

Example integration with a random neural network in 16 lines [click to expand]

You can also include arbitrary image classification neural networks not explicitly supported in this repository. For example, after installing pretrained-models.pytorch using pip, you can instantiate and pass any pretrained model into our NBDT utility functions.

import torch.nn as nn
from nbdt.model import SoftNBDT
from nbdt.loss import SoftTreeSupLoss
from nbdt.hierarchy import generate_hierarchy
import pretrainedmodels

model = pretrainedmodels.__dict__['fbresnet152'](num_classes=1000, pretrained='imagenet')

# 1. generate hierarchy from pretrained model
generate_hierarchy(dataset='Imagenet1000', arch='fbresnet152', model=model)

# 2. Fine-tune model with tree supervision loss
criterion = nn.CrossEntropyLoss()
criterion = SoftTreeSupLoss(dataset='Imagenet1000', hierarchy='induced-fbresnet152', criterion=criterion)

# 3. Run inference using embedded decision rules
model = SoftNBDT(model=model, dataset='Imagenet1000', hierarchy='induced-fbresnet152')

For more information on generating different hierarchies, see Induced Hierarchy.

Want to build and use your own induced hierarchy? [click to expand]

Use the nbdt-hierarchy utility to generate a new induced hierarchy from a pretrained model.

nbdt-hierarchy --arch=efficientnet_b0 --dataset=Imagenet1000

Then, pass the hierarchy name to the loss and models. You may alternatively pass the fully-qualified path_graph path.

from nbdt.loss import SoftTreeSupLoss
from nbdt.model import SoftNBDT

criterion = SoftTreeSupLoss(dataset='Imagenet1000', criterion=criterion, hierarchy='induced-efficientnet_b0')
model = SoftNBDT(dataset='Imagenet1000', model=model, hierarchy='induced-efficientnet_b0')

For more information on generating different hierarchies, see Induced Hierarchy.

Training and Evaluation

To reproduce experimental results, clone the repository, install all requirements, and run our bash script.

git clone git@github.com:alvinwan/neural-backed-decision-trees.git  # or use the https URL if you don't have a GitHub SSH key set up
cd neural-backed-decision-trees
python setup.py develop # install all requirements
bash scripts/gen_train_eval_wideresnet.sh # reproduce paper core CIFAR10, CIFAR100, and TinyImagenet200 results

We (1) generate the hierarchy and (2) train the neural network with a tree supervision loss. Then, we (3) run inference by featurizing images using the network backbone and running embedded decision rules. Notes:

  • See below sections for details on visualizations, reproducing ablation studies, and different configurations (e.g., different hierarchies).
  • To reproduce our ImageNet results, see examples/imagenet for ResNet and nbdt-pytorch-image-models for EfficientNet.
  • For all scripts, you can use any torchvision model or any pytorchcv model, as we directly support both model zoos. Customization for each step is explained below.

1. Generate Hierarchy

Run the following to generate and test induced hierarchies for CIFAR10, based on the WideResNet model.

nbdt-hierarchy --arch=wrn28_10_cifar10 --dataset=CIFAR10
See how it works and how to configure. [click to expand]

[Figure: induced hierarchy generation, Steps A-C]

The script loads the pretrained model (Step A), populates the leaves of the tree with fully-connected layer weights (Step B) and performs hierarchical agglomerative clustering (Step C). Note that the above command can be rerun with different architectures, different datasets, or random neural network checkpoints to produce different hierarchies.

# different architecture: ResNet18
nbdt-hierarchy --arch=ResNet18 --dataset=CIFAR10

# different dataset: ImageNet
nbdt-hierarchy --arch=efficientnet_b7 --dataset=Imagenet1000

# arbitrary checkpoint
wget https://download.pytorch.org/models/resnet18-5c106cde.pth -O resnet18.pth
nbdt-hierarchy --checkpoint=resnet18.pth --dataset=Imagenet1000
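Conceptually, Steps B and C amount to agglomerative clustering over the rows of the network's final fully-connected layer. The following is an illustrative sketch only, not the repository's implementation; it assumes scipy is installed, and the linkage method and other details are arbitrary choices that may differ from what nbdt-hierarchy does.

import torch.nn as nn
from scipy.cluster.hierarchy import linkage, dendrogram
from nbdt.models import wrn28_10_cifar10

# Step A: load a pretrained model
model = wrn28_10_cifar10(pretrained=True)

# Step B: each row of the last fully-connected layer is one leaf (one class)
fc = [m for m in model.modules() if isinstance(m, nn.Linear)][-1]
leaf_vectors = fc.weight.detach().cpu().numpy()  # shape: (num_classes, feature_dim)

# Step C: hierarchical agglomerative clustering over the class weight vectors
Z = linkage(leaf_vectors, method='ward')
print(dendrogram(Z, no_plot=True)['ivl'])  # leaf ordering of the resulting tree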

You can also run the hierarchy generation from source directly, without using the command-line tool, by passing in a pretrained model.

from nbdt.hierarchy import generate_hierarchy
from nbdt.models import wrn28_10_cifar10

model = wrn28_10_cifar10(pretrained=True)
generate_hierarchy(dataset='CIFAR10', arch='wrn28_10_cifar10', model=model)
See example visualization. [click to expand]

By default, the generation script outputs an HTML file containing a d3 visualization. All visualizations are stored in out/. Below, we generate another visualization with a larger font size that includes WordNet IDs where available.

nbdt-hierarchy --vis-sublabels --vis-zoom=1.25 --dataset=CIFAR10 --arch=wrn28_10_cifar10

The above script's output will end with the following.

==> Reading from ./nbdt/hierarchies/CIFAR10/graph-induced-wrn28_10_cifar10.json
Found just 1 root.
==> Wrote HTML to out/induced-wrn28_10_cifar10-tree.html

Open up out/induced-wrn28_10_cifar10-tree.html in your browser to view the d3 tree visualization.

[Screenshot: d3 visualization of the induced hierarchy for CIFAR10]
Want to reproduce hierarchy visualizations from the paper? [click to expand]

To generate figures from the paper, use a larger zoom and do not include sublabels. The checkpoints used to generate the induced hierarchy visualizations are included in this repository's hub of models.

nbdt-hierarchy --vis-zoom=2.5 --dataset=CIFAR10 --arch=ResNet10 --vis-force-labels-left conveyance vertebrate chordate vehicle motor_vehicle mammal placental
nbdt-hierarchy --vis-zoom=2.5 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --vis-leaf-images --vis-image-resize-factor=1.5 --vis-force-labels-left motor_vehicle craft chordate vertebrate carnivore ungulate craft
nbdt-hierarchy --vis-zoom=2.5 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --vis-color-nodes whole --vis-no-color-leaves --vis-force-labels-left motor_vehicle craft chordate vertebrate carnivore ungulate craft
[Figures: induced hierarchy visualizations for CIFAR10 with WideResNet28x10 and ResNet10]
Generate WordNet hierarchy and see how it works. [click to expand]

Run the following to generate and test WordNet hierarchies for CIFAR10, CIFAR100, and TinyImagenet200. The script also downloads the NLTK WordNet corpus.

bash scripts/generate_hierarchies_wordnet.sh

The commands below just explain what generate_hierarchies_wordnet.sh does, using CIFAR10 as an example. You do not need to run them after running the bash script above.

# Generate mapping from classes to WNID. This is required for CIFAR10 and CIFAR100.
nbdt-wnids --dataset=CIFAR10

# Generate hierarchy, using the WNIDs. This is required for all datasets: CIFAR10, CIFAR100, TinyImagenet200
nbdt-hierarchy --method=wordnet --dataset=CIFAR10
See example WordNet visualization. [click to expand]

We can generate a visualization with a slightly improved zoom and with WordNet IDs. By default, the script builds the WordNet hierarchy for CIFAR10.

nbdt-hierarchy --method=wordnet --vis-zoom=1.25 --vis-sublabels
[Screenshot: WordNet hierarchy visualization for CIFAR10]
Generate random hierarchy. [click to expand]

Use --method=random to randomly generate a binary-ish hierarchy. Optionally, use the --seed (--seed=-1 to not shuffle leaves) and --branching-factor flags. When debugging, we set branching factor to the number of classes. For example, the sanity check hierarchy for CIFAR10 is

nbdt-hierarchy --seed=-1 --branching-factor=10 --dataset=CIFAR10

2. Tree Supervision Loss

In the training commands below, we uniformly use --path-resume=<path/to/checkpoint> --lr=0.01 to fine-tune instead of training from scratch. Our results that use a recently state-of-the-art pretrained checkpoint (WideResNet) were obtained by fine-tuning. Run the following to fine-tune WideResNet with the soft tree supervision loss on CIFAR10.

python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=SoftTreeSupLoss
See how it works and how to configure. [click to expand]

[Figure: hard and soft tree supervision losses]

The tree supervision loss comes in two variants: a hard version and a soft version. Simply set the loss to HardTreeSupLoss or SoftTreeSupLoss, depending on which you want.

# fine-tune the wrn pretrained checkpoint on CIFAR10 with hard tree supervision loss
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=HardTreeSupLoss

# fine-tune the wrn pretrained checkpoint on CIFAR10 with soft tree supervision loss
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=SoftTreeSupLoss

To train from scratch, use --lr=0.1 and do not pass the --path-resume or --pretrained flags. We fine-tune WideResNet on CIFAR10 and CIFAR100, but wherever the baseline neural network accuracy is reproducible by training from scratch, we train from scratch.
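In code, switching variants just means choosing which loss wrapper to use. Below is a minimal sketch: the SoftTreeSupLoss call mirrors the earlier examples, and the HardTreeSupLoss constructor is assumed to share the same signature.

import torch.nn as nn
from nbdt.loss import HardTreeSupLoss, SoftTreeSupLoss

# hard variant: tree supervision over discrete decisions at each node
criterion = HardTreeSupLoss(dataset='CIFAR10', hierarchy='induced-wrn28_10_cifar10',
                            criterion=nn.CrossEntropyLoss())

# soft variant: tree supervision over the full distribution of path probabilities
criterion = SoftTreeSupLoss(dataset='CIFAR10', hierarchy='induced-wrn28_10_cifar10',
                            criterion=nn.CrossEntropyLoss())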

3. Inference

Like the tree supervision loss, inference comes in two variants: hard and soft. Below, we run soft inference on the model we just trained with the soft loss.

Run the following command to obtain these numbers.

python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=SoftEmbeddedDecisionRules
See how it works and how to configure. [click to expand]

[Figure: hard and soft inference modes]

Note the following commands are nearly identical to the corresponding training commands: we drop the lr and pretrained flags and add resume, eval, and the analysis type (hard or soft inference). The best results in our paper, oddly enough, were obtained by running both hard and soft inference on the neural network supervised by a soft tree supervision loss. This is reflected in the commands below; the equivalent in-code model wrappers are sketched after them.

# running soft inference on soft-supervised model
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=SoftEmbeddedDecisionRules

# running hard inference on soft-supervised model
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=HardEmbeddedDecisionRules
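The same hard/soft split exists for the in-code model wrappers. A minimal sketch follows; SoftNBDT appears earlier in this README, and the HardNBDT wrapper is assumed to share its signature.

from nbdt.model import HardNBDT, SoftNBDT
from nbdt.models import wrn28_10_cifar10

# soft inference: weight every leaf by the product of probabilities along its path
soft_model = SoftNBDT(pretrained=True, dataset='CIFAR10',
                      arch='wrn28_10_cifar10', model=wrn28_10_cifar10())

# hard inference: greedily follow the most probable child at each node
hard_model = HardNBDT(pretrained=True, dataset='CIFAR10',
                      arch='wrn28_10_cifar10', model=wrn28_10_cifar10())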
Logging maximum and minimum 'path entropy' samples. [click to expand]
# get min and max entropy samples for baseline neural network
python main.py --pretrained --dataset=TinyImagenet200 --eval --dataset-test=Imagenet1000 --disable-test-eval --analysis=TopEntropy  # or Entropy, or TopDifference

# download public checkpoint
wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-TinyImagenet200-ResNet18-induced-ResNet18-SoftTreeSupLoss-tsw10.0.pth -O checkpoint/ckpt-TinyImagenet200-ResNet18-induced-ResNet18-SoftTreeSupLoss-tsw10.0.pth

# get min and max 'path entropy' samples for NBDT
python main.py --dataset TinyImagenet200 --resume --path-resume checkpoint/ckpt-TinyImagenet200-ResNet18-induced-ResNet18-SoftTreeSupLoss-tsw10.0.pth --eval --analysis NBDTEntropyMaxMin --dataset-test=Imagenet1000 --disable-test-eval --hierarchy induced-ResNet18
Running zero-shot evaluation on superclasses. [click to expand]
# get wnids for animal and vehicle -- use the outputted wnids for below commands
nbdt-wnids --classes animal vehicle

# evaluate CIFAR10-trained ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=Superclass --superclass-wnids n00015388 n04524313 --pretrained

# download public checkpoint
wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth -O checkpoint/ckpt-CIFAR10-ResNet18-induced-SoftTreeSupLoss.pth

# evaluate CIFAR10-trained NBDT-ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=SuperclassNBDT --superclass-wnids n00015388 n04524313  --loss=SoftTreeSupLoss --resume
Visualize decision nodes using 'prototypical' samples. [click to expand]
# get wnids for animal and vehicle -- use the outputted wnids for below commands
nbdt-wnids --classes animal vehicle

# find samples representative for CIFAR10-trained ResNet18, from animal and vehicle ImageNet images
python main.py --dataset-test=Imagenet1000 --dataset=CIFAR10 --disable-test-eval --eval --analysis=VisualizeDecisionNode --vdnw=n00015388 --pretrained --superclass-wnids n00015388 n04524313  # samples for "animal" node
python main.py --dataset-test=Imagenet1000 --dataset=CIFAR10 --disable-test-eval --eval --analysis=VisualizeDecisionNode --vdnw=n00015388 --pretrained --superclass-wnids n00015388 n04524313  # for another node (e.g., "ungulate"), change --vdnw to that node's wnid

# download public checkpoint
wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth -O checkpoint/ckpt-CIFAR10-ResNet18-induced-SoftTreeSupLoss.pth

# find samples representative for CIFAR10-trained NBDT with ResNet18 backbone, from animal and vehicle ImageNet images
python main.py --dataset-test=Imagenet1000 --dataset=CIFAR10 --disable-test-eval --eval --analysis=VisualizeDecisionNode --vdnw=n01466257 --loss=SoftTreeSupLoss --resume --hierarchy=induced-ResNet18 --superclass-wnids n00015388 n04524313  # samples for "animal" node
Visualize inference probabilities in hierarchy. [click to expand]
python main.py --analysis=VisualizeHierarchyInference --eval --pretrained # soft inference by default

Results

We compare against all previous decision-tree-based methods that report on CIFAR10, CIFAR100, and/or ImageNet; we use the numbers reported in the original papers (except for DNDF, which did not report CIFAR or ImageNet top-1 scores):

  • Deep Neural Decision Forest (DNDF, updated with ResNet18)
  • Explainable Observer-Classifier (XOC)
  • Deep Convolutional Decision Jungle (DCDJ)
  • Network of Experts (NofE)
  • Deep Decision Network (DDN)
  • Adaptive Neural Trees (ANT)
  • Oblique Decision Trees (ODT)
  • Classic Decision Trees
                       CIFAR10   CIFAR100   TinyImagenet200   ImageNet
NBDT (Ours)            97.55%    82.97%     67.72%            76.60%
Best Pre-NBDT Acc      94.32%    76.24%     44.56%            61.29%
Best Pre-NBDT Method   DNDF      NofE       DNDF              NofE
Our improvement        3.23%     6.73%      23.16%            15.31%

Our pretrained checkpoints (CIFAR10, CIFAR100, and TinyImagenet200) may deviate from these numbers by 0.1-0.2%, as we retrained all models for public release.

Customize Repository for Your Application

As discussed above, you can use the nbdt python library to integrate NBDT training into any existing training pipeline, like ClassyVision (ClassyVision + NBDT Imagenet example). However, if you wish to use the barebones training utilities here, refer to the following sections for adding custom models and datasets.

If you have not already, start by cloning the repository and installing all requirements. As a sample, we've included copies of the WideResNet bash script but for ResNet18.

git clone git@github.com:alvinwan/neural-backed-decision-trees.git  # or use the https URL if you don't have a GitHub SSH key set up
cd neural-backed-decision-trees
python setup.py develop
bash scripts/gen_train_eval_resnet.sh

For any models that have pretrained checkpoints for the datasets of interest (e.g., CIFAR10, CIFAR100, and ImageNet models from pytorchcv or ImageNet models from torchvision), modify scripts/gen_train_eval_pretrained.sh; it suffices to change the model name. For all models that do not have a pretrained checkpoint for the dataset of interest, modify scripts/gen_train_eval_nopretrained.sh.

Models

Without any modifications to main.py, you can replace ResNet18 with your favorite network: pass any torchvision.models model or any pytorchcv model to --arch, as we directly support both model zoos. Note that the former only supports models pretrained on ImageNet. The latter supports models pretrained on CIFAR10, CIFAR100, and ImageNet; for each dataset, the corresponding model name includes the dataset, e.g., wrn28_10_cifar10. However, neither supports models pretrained on TinyImagenet200.

To add a new model from scratch:

  1. Create a new file containing your network, such as ./nbdt/models/yournet.py. This file should contain an __all__ exposing only functions that return a model. These functions should accept pretrained: bool and progress: bool, then forward all other keyword arguments to the model constructor (a hypothetical sketch follows this list).
  2. Expose your new file via ./nbdt/models/__init__.py: from .yournet import *.
  3. Train the original neural network on the target dataset, e.g., python main.py --arch=yournet18.
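For example, a hypothetical ./nbdt/models/yournet.py might look like the sketch below; yournet18 and its architecture are placeholders, not part of this repository.

import torch.nn as nn

__all__ = ('yournet18',)  # only expose functions that return a model


class YourNet(nn.Module):
    """Placeholder architecture; replace with your own."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.linear = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.linear(self.features(x))


def yournet18(pretrained=False, progress=True, **kwargs):
    """Accepts pretrained and progress, forwarding other kwargs to the constructor."""
    model = YourNet(**kwargs)
    if pretrained:
        # no pretrained weights are bundled with this placeholder
        raise NotImplementedError('No pretrained checkpoint for yournet18')
    return model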

Dataset

Without any modifications to main.py, you can use any image classification dataset found at torchvision.datasets by passing it to --dataset. To add a new dataset from scratch:

  1. Create a new file containing your dataset, such as ./nbdt/data/yourdata.py. Say the data class is YourData10. Like before, only expose the dataset class via __all__. This dataset class should expose a .classes attribute, which returns a list of human-readable class names (a hypothetical sketch follows the note below).
  2. Expose your new file via './nbdt/data/__init__.py': from .yourdata import *.
  3. Modify nbdt.utils.DATASETS to include the name of your dataset, which is YourData10 in this example.
  4. Also in nbdt/utils.py, modify DATASET_TO_NUM_CLASSES and DATASET_TO_CLASSES to include your new dataset.
  5. (Optional) Create a text file with wordnet IDs in ./nbdt/wnids/{dataset}.txt. This list should be in the same order as your dataset's .classes. You may optionally use the nbdt-wnids utility to generate wnids (see the note below).
  6. Train the original neural network on the target dataset, e.g., python main.py --dataset=YourData10.

*Note: You may optionally use the utility nbdt-wnids to generate wnids:

nbdt-wnids --dataset=YourData10

where YourData10 is your dataset name. If a provided class name from YourData10.classes does not exist in the WordNet corpus, the script will generate a fake wnid. This does not affect training, but subsequent analysis scripts will be unable to provide WordNet-imputed node meanings.
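For example, a hypothetical ./nbdt/data/yourdata.py could wrap an existing torchvision dataset while exposing the required .classes attribute. Everything here (the directory layout, the train/download keyword arguments, and the class name) is an assumption for illustration; adjust it to match how your training script constructs datasets.

from torchvision.datasets import ImageFolder

__all__ = ('YourData10',)  # only expose the dataset class


class YourData10(ImageFolder):
    """Placeholder 10-class dataset built on ImageFolder.

    ImageFolder already provides the required .classes attribute: a list of
    human-readable class names in a fixed order.
    """

    def __init__(self, root='./data/yourdata10', *args, train=True, download=False, **kwargs):
        # train/download mirror the torchvision CIFAR-style constructor that a
        # training script may expect; download is accepted but ignored here.
        split = 'train' if train else 'val'
        super().__init__(f'{root}/{split}', *args, **kwargs)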

Tests

To run tests, use the following command

pytest nbdt tests

Citation

If you find this work useful for your research, please cite our paper:

@misc{nbdt,
    title={NBDT: Neural-Backed Decision Trees},
    author={Alvin Wan and Lisa Dunlap and Daniel Ho and Jihan Yin and Scott Lee and Henry Jin and Suzanne Petryk and Sarah Adel Bargal and Joseph E. Gonzalez},
    year={2020},
    eprint={2004.00221},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
Comments
  • Error when applying SoftTreeSupLoss to custom model/dataset

    Hey, thanks for sharing your code. I have a pretrained PyTorch (fastai) model on my custom dataset, and when I try to run the code below on my dataset, I get an assertion error.

    The dataset I have is loaded using a dataloader with 3 classes.

    from nbdt.loss import SoftTreeSupLoss
    criterion = nn.CrossEntropyLoss()
    criterion = SoftTreeSupLoss(dataset='CIFAR10', criterion=criterion)  # original
    criterion = SoftTreeSupLoss(dataset=data, criterion=criterion)  # custom

    Also, do I have to train the whole model again, or can I just pass this new loss function to the model and get the inference?

    Any help would be really appreciated :)

    opened by akshatshreemali 5
  • Extending NBDT to the NLP domain

    Hello @alvinwan and team!

    First off, thank you for the awesome work. As per this repo and your paper, I noticed that the applications of the NBDT technique have been limited to images. Do you foresee this technique being used in other domains such as NLP?

    I am interested in working on an interpretable decision-tree learning framework which can come close to DNNs' performance in NLP, and am drawing some inspiration from your work. Would be great to get your thoughts on this :)

    opened by atreyasha 4
  • [Question][Bug?] Why are you using FC layer outputs instead of Neural Backbone outputs during inference?

    According to the paper at 3.1 Inference with Embedded Decision Rules

    First, our NBDT approach featurizes each sample using the neural network backbone; the backbone consists of all neural network layers before the final fully-connected layer.

    So, it means the sample is run through the NN architecture excluding the final fully-connected layer(?)

    If so, why are you including the final fully-connected layer here before passing x to self.rules.forward_with_decisions? https://github.com/alvinwan/neural-backed-decision-trees/blob/7ef5fe5034281aeb0d10786495d5556e99b98af4/nbdt/model.py#L323-L324

    Can you please explain?

    opened by sukeesh 2
  • Change '--model' flag to '--arch' in bash scripts

    This PR changes the "--model" flags in the bash scripts to "--arch". The "--model" flag does not appear in the CLI argument-parsing code in main.py.

    opened by MDutro 1
  • BrokenPipeError: [Errno 32] Broken pipe

    When I cloned the repository and ran main.py (without changing anything), I got this output:

     not enough values to unpack (expected 2, got 0)
     ==> Preparing data..
     Files already downloaded and verified
     Files already downloaded and verified
     Training with dataset CIFAR10 and 10 classes 
     ==> Building model..
     ==> Checkpoints will be saved to: ./checkpoint/ckpt-CIFAR10-ResNet18.pth
     classes:	(callable) 
    
     Epoch: 0
     Traceback (most recent call last):
    
    File "<ipython-input-7-f9bd8031870b>", line 1, in <module>
      runfile('C:/Users/Matthew Chen/Documents/GitHub/neural-backed-decision-trees/main.py', wdir='C:/Users/Matthew Chen/Documents/GitHub/neural-backed-decision-trees')
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 827, in runfile
      execfile(filename, namespace)
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 110, in execfile
      exec(compile(f.read(), filename, 'exec'), namespace)
    
    File "C:/Users/Matthew Chen/Documents/GitHub/neural-backed-decision-trees/main.py", line 315, in <module>
      train(epoch, analyzer)
    
    File "C:/Users/Matthew Chen/Documents/GitHub/neural-backed-decision-trees/main.py", line 227, in train
      for batch_idx, (inputs, targets) in enumerate(trainloader):
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 279, in __iter__
      return _MultiProcessingDataLoaderIter(self)
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 719, in __init__
      w.start()
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\multiprocessing\process.py", line 112, in start
      self._popen = self._Popen(self)
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
      return _default_context.get_context().Process._Popen(process_obj)
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
      return Popen(process_obj)
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 89, in __init__
      reduction.dump(process_obj, to_child)
    
    File "C:\Users\Matthew Chen\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
      ForkingPickler(file, protocol).dump(obj)
    
      BrokenPipeError: [Errno 32] Broken pipe
    

    I'm not sure if I downloaded all the packages correctly, but this seems to be an error where the request to some server is blocked/timed out?

    opened by MatthewChen37 1
  • batch-size of training ImageNet

    Hello, with the current ImageNet code settings I can only fit batch-size=16 on one GPU for a ResNet50 model on ImageNet; my GPU is a V100. However, with the same settings my friend's machine fits batch-size=36.

    opened by Muzijiajian 1
  • "RuntimeError: CUDA out of memory" when training with soft tree supervision loss

    Hello,

    I am trying to integrate the RarePlanes dataset with this code repository. I followed the steps for a custom dataset in the README and was able to get one of the scripts partially running: I can run 'step 0' and 'step 1' of gen_train_eval_nopretrained.sh; however, when I run 'step 2', I get a CUDA out of memory error. Any ideas on why this would happen?

    Note: I have tried lowering the batch size but that does not seem to affect the error message.

    For reference here is the script I am running:

    # Want to train with wordnet hierarchy? Just set '--hierarchy=wordnet' below.
    # This script is for networks that do NOT come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself.
    
    model="ResNet18"
    dataset=RarePlanes
    weight=1
    batch_size=4
    
    # 0. train the baseline neural network
    python main.py --dataset=${dataset} --arch=${model} --batch-size=${batch_size}
    
    # 1. generate hierarchy -- for models without a pretrained checkpoint, use 'checkpoint'
    nbdt-hierarchy --dataset=${dataset} --checkpoint=./checkpoint/ckpt-${dataset}-${model}.pth
    
    # 2. train with soft tree supervision loss -- for models without a pretrained checkpoint, use 'path-resume' OR just train from scratch, without 'path-resume'
    # python main.py --lr=0.01 --dataset=${dataset} --model=${model} --hierarchy=induced-${model} --path-resume=./checkpoint/ckpt-${dataset}-${model}.pth --loss=SoftTreeSupLoss --tree-supervision-weight=${weight}  # fine-tuning
    python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --tree-supervision-weight=${weight}  # training from scratch
    
    # 3. evaluate with soft then hard inference
    for analysis in SoftEmbeddedDecisionRules HardEmbeddedDecisionRules; do
      python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight}
    done
    

    And here is the error I get on line 17:

    Training with dataset RarePlanes and 54 classes 
    ==> Building model..
    ==> Checkpoints will be saved to: ./checkpoint/ckpt-RarePlanes-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth
    path_graph:     /home/pguerrie/neural-backed-decision-trees/nbdt/hierarchies/RarePlanes/graph-induced-ResNet18.json 
    path_wnids:     /home/pguerrie/neural-backed-decision-trees/nbdt/wnids/RarePlanes.txt 
    tree_supervision_weight:        1.0 
    classes:        (callable) 
    dataset:        (callable) 
    criterion:      (callable) 
    classes:        (callable) 
    
    Epoch: 0
    Traceback (most recent call last):
      File "main.py", line 315, in <module>
        train(epoch, analyzer)
      File "main.py", line 230, in train
        outputs = net(inputs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/pguerrie/neural-backed-decision-trees/nbdt/models/resnet.py", line 112, in forward
        out = self.features(x)
      File "/home/pguerrie/neural-backed-decision-trees/nbdt/models/resnet.py", line 102, in features
        out = F.relu(self.bn1(self.conv1(x)))
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 107, in forward
        exponential_average_factor, self.eps)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py", line 1670, in batch_norm
        training, momentum, eps, torch.backends.cudnn.enabled
    RuntimeError: CUDA out of memory. Tried to allocate 6.12 GiB (GPU 0; 11.91 GiB total capacity; 6.45 GiB already allocated; 4.86 GiB free; 6.47 GiB reserved in total by PyTorch)
    

    I don't understand why I would be able to run step 0 with a large batch size but I can't run step 2 even with a very small batch size. I was thinking the networks were largely the same (except for the extra output nodes in step 2 due to the hierarchy and the hierarchical loss used to train). Any help would be greatly appreciated!

    opened by guerriep 1
  • Getting the intermediate decisions on Colab

    When I call model.forward_with_decisions to get the intermediate decisions on Colab, I get a list of 'node': <nbdt.data.custom.Node object at 0x7f2318a35a90> objects. How can I get access to the information within the node? And how can I generate a visualization of a decision tree based on this result?

    opened by WesDeng 1
  • commands not available

    nbdt https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32

    OR run on a local image

    nbdt /imaginary/path/to/local/image.png

    Is this not available yet?

    opened by innerpace-X 1
  • please advise

    Hi NBDT,

    When I ran the code below....

    from nbdt.model import SoftNBDT
    from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10  # use wrn28_10 for TinyImagenet200

    model = wrn28_10_cifar10()
    model = SoftNBDT(
      pretrained=True,
      dataset='CIFAR10',
      arch='wrn28_10_cifar10',
      model=model)

    I received the following error

    not enough values to unpack (expected 2, got 0)
    Downloading: "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR10-wrn28_10_cifar10-induced-wrn28_10_cifar10-SoftTreeSupLoss.pth" to /home/jupyter/.cache/torch/checkpoints/ckpt-CIFAR10-wrn28_10_cifar10-induced-wrn28_10_cifar10-SoftTreeSupLoss.pth

    Did it work? Thank you Vivek

    opened by KhatriVivek 0
  • Query about Loss during Training

    Just wondering, what is the loop "for _loss in args.loss:" on line 195 of main.py for? It seems like the loss is always overridden by the latest one.

    opened by ericotjo001 0
  • Does NBDT support JIT?

    Hi,

    Great project and article, congratulations.

    I'm wondering if, once the training is finished, the final NBDT can be fully exported as a JIT file?

    It seems that everything is written in PyTorch and inference can be done in a single forward pass, so it should be fine, but I'm still a bit worried about the tree part and self.rules.forward_with_decisions(x). Have you ever tried to export one of your models as JIT? Does it work as is?

    Thanks in advance for your answer.

    opened by Optimox 0
  • Does it work better than the original neural network in binary classification?

    I mean applying NBDT with the soft loss. You know, the inner node weights can't be updated with the soft loss in binary classification, due to the special way the inner-node probabilities are computed.

    opened by study1157 0
  • How to train with new dataset

    Hi, I'm trying to use gen_train_eval_nopretrained.sh to train with a new dataset I implemented. However, main.py has this line of code: tree = Tree.create_from_args(args, classes=trainset.classes). The error I got is FileNotFound at nbdt/hierarchies/mydataset/graph-induced.json, and it seems that this line requires a generated hierarchy. But I can only generate the hierarchy after I've trained the model, so I'm a bit confused about what to do.

    opened by XAVILLA 1
  • How do I unpack the output of model.forward_with_decisions(x) ?

    Can I find documentation somewhere? I've found the Colab notebook useful for a single prediction, but I can't get the different labels from a forward-with-decisions call.

    opened by MarkTensenSgt 1