A modular domain adaptation library written in PyTorch.

News

November 19: Git repo is now public

Google Colab Examples

See the examples folder for notebooks you can download or run on Google Colab.

Overview

This library consists of 11 modules:

| Module | Description |
| - | - |
| Adapters | Wrappers for training and inference steps |
| Containers | Dictionaries for simplifying object creation |
| Datasets | Commonly used datasets and tools for domain adaptation |
| Frameworks | Wrappers for training/testing pipelines |
| Hooks | Modular building blocks for domain adaptation algorithms |
| Layers | Loss functions and helper layers |
| Meta Validators | Post-processing of metrics, for hyperparameter optimization |
| Models | Architectures used for benchmarking and in examples |
| Utils | Various tools |
| Validators | Metrics for determining and estimating accuracy |
| Weighters | Functions for weighting losses |

How to...

Use in vanilla PyTorch

from tqdm import tqdm

from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device

# Assuming that models, optimizers, dataloader, and device are already created.
hook = DANNHook(optimizers)
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    # Optimization is done inside the hook.
    # The returned losses are for logging.
    _, loss = hook({**models, **data})
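
Each hook is a callable that receives a dict of inputs (models plus data) and returns outputs and losses; per the v0.0.60 release notes, the current call order is output, losses = hook(inputs, losses). The toy sketch below is library-independent and only illustrates how callables with that signature can be chained, with each one seeing the inputs plus all earlier outputs. The names chain, toy_features_hook, and toy_loss_hook are hypothetical; this is not pytorch_adapt's ChainHook implementation.

```python
from typing import Any, Callable, Dict, Optional, Tuple

Outputs = Dict[str, Any]
Losses = Dict[str, float]
Hook = Callable[[Dict[str, Any], Optional[Losses]], Tuple[Outputs, Losses]]


def chain(*hooks: Hook) -> Hook:
    """Hypothetical combinator: run hooks in order, sharing outputs and losses."""

    def chained(inputs: Dict[str, Any], losses: Optional[Losses] = None) -> Tuple[Outputs, Losses]:
        all_outputs: Outputs = {}
        all_losses: Losses = dict(losses or {})
        for hook in hooks:
            # Each hook sees the original inputs plus everything computed so far.
            outputs, new_losses = hook({**inputs, **all_outputs}, all_losses)
            all_outputs.update(outputs)
            all_losses.update(new_losses)
        return all_outputs, all_losses

    return chained


def toy_features_hook(inputs, losses=None):
    # Stand-in for a generator G: the "features" are the inputs doubled.
    return {"src_features": [2 * x for x in inputs["src_imgs"]]}, {}


def toy_loss_hook(inputs, losses=None):
    # Stand-in for a loss hook: the "loss" is the mean feature value.
    feats = inputs["src_features"]
    return {}, {"c_loss": sum(feats) / len(feats)}


hook = chain(toy_features_hook, toy_loss_hook)
outputs, losses = hook({"src_imgs": [1.0, 2.0, 3.0]})
print(outputs)  # {'src_features': [2.0, 4.0, 6.0]}
print(losses)  # {'c_loss': 4.0}
```

Keeping outputs and losses in separate dicts lets later hooks reuse intermediate results (such as features) without recomputing them, while the losses dict stays available for logging.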

Build complex algorithms

Let's customize DANNHook with:

  • virtual adversarial training
  • entropy conditioning

import torch

from pytorch_adapt.hooks import DANNHook, EntropyReducer, MeanReducer, VATHook

# G and C are the Generator and Classifier models
misc = {"combined_model": torch.nn.Sequential(G, C)}
reducer = EntropyReducer(
    apply_to=["src_domain_loss", "target_domain_loss"], default_reducer=MeanReducer()
)
hook = DANNHook(optimizers, reducer=reducer, post_g=[VATHook()])
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    _, loss = hook({**models, **data, **misc})
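
Entropy conditioning (introduced with CDAN+E) reweights per-sample domain losses so that samples with confident, low-entropy predictions count more, using the weight 1 + exp(-H(p)). The plain-Python sketch below shows that idea; the rescaling of the weights to average 1 is an assumption, and EntropyReducer's exact computation may differ. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def entropy_weights(batch_probs: Sequence[Sequence[float]]) -> List[float]:
    """CDAN+E-style weights 1 + exp(-H), rescaled (by assumption) to average 1."""
    raw = [1.0 + math.exp(-entropy(p)) for p in batch_probs]
    total = sum(raw)
    return [len(raw) * w / total for w in raw]


def entropy_weighted_mean(losses: Sequence[float], batch_probs: Sequence[Sequence[float]]) -> float:
    """Weighted mean of per-sample losses: confident samples count more."""
    weights = entropy_weights(batch_probs)
    return sum(l * w for l, w in zip(losses, weights)) / len(losses)


probs = [[1.0, 0.0], [0.5, 0.5]]  # one confident sample, one maximally uncertain
print(entropy_weights(probs))
```

A fully confident prediction gets a raw weight of 2 and a uniform two-class prediction gets 1.5, so after rescaling the confident sample's loss is weighted more heavily.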

Wrap with your favorite PyTorch framework

For additional functionality, adapters can be wrapped with a framework (currently just PyTorch Ignite).

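import torch

from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.datasets import DataloaderCreator
from pytorch_adapt.frameworks.ignite import Ignite

# Assume G, C and D are existing models, collected as models = {"G": G, "C": C, "D": D},
# and that datasets is a dict of datasets (e.g. the output of get_mnist_mnistm)
models_cont = Models(models)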
# Override the default optimizer for G and C
optimizers_cont = Optimizers((torch.optim.Adam, {"lr": 0.123}), keys=["G", "C"])
adapter = DANN(models=models_cont, optimizers=optimizers_cont)

dc = DataloaderCreator(num_workers=2)
trainer = Ignite(adapter)
trainer.run(datasets, dataloader_creator=dc)

Wrappers for other frameworks (e.g. PyTorch Lightning and Catalyst) are planned to be added.

Check your model's performance

You can do this in vanilla PyTorch:

from pytorch_adapt.validators import SNDValidator

# Assuming predictions have been collected
target_train = {"preds": preds}
validator = SNDValidator()
score = validator.score(epoch=1, target_train=target_train)
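
SND (soft neighborhood density, Saito et al., 2021) scores unlabeled target predictions: L2-normalize each prediction vector, compute its similarity to every other sample, turn those similarities into a distribution with a temperature-scaled softmax, and average the entropies of those distributions. The pure-Python sketch below follows that description; the temperature default of 0.05 and the function names are assumptions, and SNDValidator's defaults and details may differ.

```python
import math
from typing import List, Sequence


def _l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def snd_score(preds: Sequence[Sequence[float]], temperature: float = 0.05) -> float:
    """Mean entropy of each sample's softmax over similarities to all other samples.

    Requires at least two samples.
    """
    feats = [_l2_normalize(p) for p in preds]
    n = len(feats)
    total = 0.0
    for i in range(n):
        # Cosine similarities to every other sample, scaled by the temperature.
        sims = [
            sum(a * b for a, b in zip(feats[i], feats[j])) / temperature
            for j in range(n)
            if j != i
        ]
        probs = _softmax(sims)
        total += -sum(p * math.log(p) for p in probs if p > 0)
    return total / n


print(snd_score([[0.9, 0.1], [0.8, 0.2], [0.1, 0.9]]))
```

Because the softmax is taken over neighbors rather than classes, the score reflects how spread out each sample's close neighbors are, which requires no target labels.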

You can also do this using a framework wrapper:

validator = SNDValidator()
trainer = Ignite(adapter, validator=validator)
trainer.run(datasets, dataloader_creator=dc)

Run the above examples

See this notebook and the examples page for other notebooks.

Installation

Pip

pip install pytorch-adapt

To get the latest dev version:

pip install pytorch-adapt --pre

To use pytorch_adapt.frameworks.ignite:

pip install pytorch-adapt[ignite]

Conda

Coming soon...

Dependencies

Required dependencies:

  • numpy
  • torch >= 1.6
  • torchvision
  • torchmetrics
  • pytorch-metric-learning >= 1.0.0.dev5

Acknowledgements

Contributors

Pull requests are welcome!

Advisors

Thank you to Ser-Nam Lim, and my research advisor, Professor Serge Belongie.

Logo

Thanks to Jeff Musgrave for designing the logo.

Comments
  • How do I change a few things in the implementation?

    I would like to run on my own dataset, and also, print the accuracy on the source and target domains. In the paper implementations, I don't see any accuracy metric. Please guide the amateur learner looking at this to change the dataset to their own, and to implement accuracy terms. Thanks

    question 
    opened by chiragpr 20
  • Extension of the TargetDataset class.

    Suggested Feature

    A) The addition of a new TargetDataset class for supervised domain adaptation.

    or

    B) The extension of the TargetDataset class to return labels when passed a supervised flag.

    I think option B could be cleaner?

    Implementation

    A) Create a new class capable of returning target_labels named something like SupervisedTargetDataset.

    or

    B) Update the init function of the TargetDataset to include a supervised flag.

        def __init__(self, dataset: Dataset, domain: int = 1, supervised: bool = False):
            """
            Arguments:
                dataset: The dataset to wrap
                domain: An integer representing the domain.
                supervised: If True, __getitem__ also returns target_labels.
            """
            super().__init__(dataset, domain)
            self.supervised = supervised
    

    Update the __getitem__ method to behave differently under supervised domain adaptation.

        def __getitem__(self, idx: int) -> Dict[str, Any]:
            
            if self.supervised:
                img, target_labels = self.dataset[idx]
                return {
                    "target_imgs": img,
                    "target_domain": self.domain,
                    "target_labels": target_labels,
                    "target_sample_idx": idx,
                }
            else:
                img, _ = self.dataset[idx]
                return {
                    "target_imgs": img,
                    "target_domain": self.domain,
                    "target_sample_idx": idx,
                }
    

    Reasoning

    To run supervised domain adaptation, we need labels in the target domain, but I think it would still be useful to distinguish between the two domains using different classes, rather than wrapping the target data in SourceDataset to achieve the same functionality.

    With this change validators such as AccuracyValidator could be used on target_val in a supervised domain adaptation setting.


    BTW: With these feature suggestions I am happy to do code PRs along with the docs as I previously mentioned!

    opened by deepseek-eoghan 10
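
Option B can be sketched without any library code. ToyTargetDataset is hypothetical and wraps a plain list of (image, label) pairs instead of a torch Dataset; the returned keys mirror the proposal. The v0.0.71 release notes describe the supervised flag that TargetDataset eventually gained.

```python
from typing import Any, Dict, Sequence, Tuple


class ToyTargetDataset:
    """Hypothetical stand-in for a TargetDataset with a supervised flag."""

    def __init__(self, dataset: Sequence[Tuple[Any, Any]], domain: int = 1, supervised: bool = False):
        self.dataset = dataset
        self.domain = domain
        self.supervised = supervised

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        img, label = self.dataset[idx]
        item = {"target_imgs": img, "target_domain": self.domain, "target_sample_idx": idx}
        if self.supervised:
            # Labels are only exposed for supervised domain adaptation.
            item["target_labels"] = label
        return item


pairs = [("img0", 3), ("img1", 7)]
print(ToyTargetDataset(pairs)[1])
print(ToyTargetDataset(pairs, supervised=True)[1])
```

Keeping one class with a flag, rather than a second class, means the unsupervised and supervised cases share the same keys apart from target_labels.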
  • Question on DataloaderCreator - How to create test sets

    Hello,

    Well done on putting together this library. I think it will be extremely useful for many people undertaking domain adaptation projects.

    I am wondering how to create a test dataset using the DataloaderCreator class?

    Some background on my issue.

    I am using the MNISTM example within a PyTorch Lightning DataModule.

    Adapting the code from the examples/DANNLightning.ipynb I have the following code.

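    import os
    from typing import Optional

    from pytorch_lightning import LightningDataModule
    from torch.utils.data import Dataset

    from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm
    from pytorch_adapt.frameworks.utils import filter_datasets
    from pytorch_adapt.validators import IMValidator

    class MnistAdaptDataModule(LightningDataModule):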
        def __init__(
            self,
            data_dir: str = "data/mnistm/",
            batch_size: int = 4,
            num_workers: int = 0,
            pin_memory: bool = False,
        ):
            super().__init__()
    
            # this line allows access to init params via the 'self.hparams' attribute
            # it also ensures init params will be stored in ckpt
            self.save_hyperparameters(logger=False)
    
            self.data_train: Optional[Dataset] = None
            self.data_val: Optional[Dataset] = None
            self.data_test: Optional[Dataset] = None
            self.dataloaders = None
    
        def prepare_data(self):
            if not os.path.exists(self.hparams.data_dir):
                print("downloading dataset")
                get_mnist_mnistm(["mnist"], ["mnistm"], folder=self.hparams.data_dir, download=True)
            return
    
    
        def setup(self, stage: Optional[str] = None):
            if not self.data_train and not self.data_val and not self.data_test:
                datasets = get_mnist_mnistm(["mnist"], ["mnistm"], folder=self.hparams.data_dir, download=False)
                dc = DataloaderCreator(batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
                validator = IMValidator()
                self.dataloaders = dc(**filter_datasets(datasets, validator))
                self.data_train = self.dataloaders.pop("train")
                self.data_val = list(self.dataloaders.values())
                return            
    
        def train_dataloader(self):
            return self.data_train
    
        def val_dataloader(self):
            return self.data_val
    
        def test_dataloader(self):
            # how to make a test dataset?
            return
    
    

    self.dataloaders produces the following object

    {'src_train': SourceDataset(
      domain=0
      (dataset): ConcatDataset(
        len=60000
        (datasets): [Dataset MNIST
            Number of datapoints: 60000
            Root location: /home/eoghan/Code/mnist-domain-adaptation/data/mnist_adapt/
            Split: Train
            StandardTransform
        Transform: Compose(
                       Resize(size=32, interpolation=bilinear, max_size=None, antialias=None)
                       ToTensor()
                       <pytorch_adapt.utils.transforms.GrayscaleToRGB object at 0x7fd1badcbdc0>
                       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                   )]
      )
    ), 'src_val': SourceDataset(
      domain=0
      (dataset): ConcatDataset(
        len=10000
        (datasets): [Dataset MNIST
            Number of datapoints: 10000
            Root location: /home/eoghan/Code/mnist-domain-adaptation/data/mnist_adapt/
            Split: Test
            StandardTransform
        Transform: Compose(
                       Resize(size=32, interpolation=bilinear, max_size=None, antialias=None)
                       ToTensor()
                       <pytorch_adapt.utils.transforms.GrayscaleToRGB object at 0x7fd1badcb6a0>
                       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                   )]
      )
    ), 'target_train': TargetDataset(
      domain=1
      (dataset): ConcatDataset(
        len=59001
        (datasets): [MNISTM(
          domain=MNISTM
          len=59001
          (transform): Compose(
              ToTensor()
              Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
          )
        )]
      )
    ), 'target_val': TargetDataset(
      domain=1
      (dataset): ConcatDataset(
        len=9001
        (datasets): [MNISTM(
          domain=MNISTM
          len=9001
          (transform): Compose(
              ToTensor()
              Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
          )
        )]
      )
    ), 'train': CombinedSourceAndTargetDataset(
      (source_dataset): SourceDataset(
        domain=0
        (dataset): ConcatDataset(
          len=60000
          (datasets): [Dataset MNIST
              Number of datapoints: 60000
              Root location: /home/eoghan/Code/mnist-domain-adaptation/data/mnist_adapt/
              Split: Train
              StandardTransform
          Transform: Compose(
                         Resize(size=32, interpolation=bilinear, max_size=None, antialias=None)
                         ToTensor()
                         <pytorch_adapt.utils.transforms.GrayscaleToRGB object at 0x7fd125f69d60>
                         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                     )]
        )
      )
      (target_dataset): TargetDataset(
        domain=1
        (dataset): ConcatDataset(
          len=59001
          (datasets): [MNISTM(
            domain=MNISTM
            len=59001
            (transform): Compose(
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )
          )]
        )
      )
    )}
    

    This handles train and val for source and target as well as creating a conjoined train dataset.

    Going by the example ipynb, the concat dataset for train (of source and target) is used as the training dataset for the model.

    The validation set is a list of the remaining keys in the data-loader and has the following form.

    [
    <torch.utils.data.dataloader.DataLoader object at 0x7fd1063e6b80> {
        dataset: TargetDataset(
      domain=1
      (dataset): ConcatDataset(
        len=59001
        (datasets): [MNISTM(
          domain=MNISTM
          len=59001
          (transform): Compose(
              ToTensor()
              Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
          )
        )]
      )
    )
    }
    ]
    

    I am not sure why this is the validation dataset. Do we validate only on the target domain? How would we handle this validation set if the target domain is unlabeled? If you could explain why this is the case, I would appreciate the insight.

    In summary, I am looking for guidance on how to use something like torch.utils.data.random_split to take some of the source and target data and use the DataloaderCreator to pass back test sets along with train and val. Is this possible within the framework?

    Many thanks, Eoghan

    question 
    opened by deepseek-eoghan 6
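
One way to carve a held-out test set out of existing data is to shuffle the indices with a fixed seed and cut them into disjoint chunks. torch.utils.data.random_split(dataset, lengths, generator=...) does this for PyTorch datasets, and an index list can be wrapped with torch.utils.data.Subset. The stdlib sketch below shows the same idea with a hypothetical split_indices helper; it is not a DataloaderCreator feature, and the 9001 length simply mirrors the target_val size shown above.

```python
import random
from typing import List, Sequence


def split_indices(n: int, fractions: Sequence[float], seed: int = 0) -> List[List[int]]:
    """Shuffle range(n) with a fixed seed and cut it into disjoint chunks."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    sizes = [int(f * n) for f in fractions]
    sizes[-1] = n - sum(sizes[:-1])  # give any rounding remainder to the last split
    chunks, start = [], 0
    for size in sizes:
        chunks.append(idx[start:start + size])
        start += size
    return chunks


val_idx, test_idx = split_indices(9001, [0.5, 0.5], seed=42)
print(len(val_idx), len(test_idx))  # 4500 4501
```

Fixing the seed keeps the validation/test partition identical across runs, so a checkpoint chosen on the validation split is never evaluated on data it was selected with.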
  • Use DANN with target labels

    Hi @KevinMusgrave,

    How can I use DANN with target labels? I tried the following:

    from pytorch_adapt.hooks import DANNHook, CLossHook, FeaturesAndLogitsHook
    G.count, C.count, D.count = 0, 0, 0
    f_hook =  FeaturesAndLogitsHook(domains = ["src", "target"])
    c_hook = CLossHook(f_hook=f_hook)
    hook = DANNHook(opts,c_hook=c_hook)
    model_counts = validate_hook(hook, list(data.keys()))
    outputs, losses = hook({**models, **data})
    print_info(model_counts, outputs, losses, G, C, D)
    

    But I'm having this issue:

    ValueError: in DANNHook: __call__
    in ChainHook: __call__
    in OptimizerHook: __call__
    in ChainHook: __call__
    in ChainHook: __call__
    in CLossHook: __call__
    too many values to unpack (expected 1)
    

    Thanks in advance!

    opened by rtaiello 2
  • Specific Architecture

    Hi @KevinMusgrave,

    I'm experimenting with the library and would like to ask the following question, since I think what I want to do is easily doable with the library's features.

    I would like to implement the following architecture: given two separate source datasets (src_1, src_2), two independent generators (G_1, G_2), and two independent classifiers (C_1, C_2), features_1 = G_1(src_1) is the input of C_1 and, likewise, features_2 = G_2(src_2) is the input of C_2. Both features_1 and features_2 are passed to D (DANN's discriminator), which is shared.

    Many thanks in advance!

    opened by rtaiello 2
  • Saving and Restoring a Trained Model

    Hi, this is roughly the code that I am using for training my models:

    models = Models({"G": G, "C": C, "D": D})
    adapter = DANN(models=models)
    validator = IMValidator()
    dataloaders = dc(**filter_datasets(datasets, validator))
    train_loader = dataloaders.pop("train")
    
    L_adapter = Lightning(adapter, validator=validator)
    trainer = pl.Trainer(gpus=1, 
                         max_epochs=1,
                         default_root_dir="saved_models",
                         enable_checkpointing=True)
    trainer.fit(L_adapter, train_loader, list(dataloaders.values()))
    

    which causes the latest model to be saved under saved_models/lightning_logs/version_0/checkpoints/epoch=1-step=2832.ckpt.

    Question 1: Is it possible to restore all three models, G, C and D, from this checkpoint, and if so, how? I know that Lightning provides the function load_from_checkpoint(), but I can't get it to work.

    Question 2: If it is not possible to restore these models from the Lightning checkpoint, should I instead manually save the state_dicts of G, C and D and then manually restore them, or is there a more elegant way?

    opened by r0f1 2
  • No module named pytorch_adapt

    I ran pip install pytorch-adapt, but whenever I run "import pytorch_adapt" in the Python shell (Linux), I get this error. Where am I going wrong?

    opened by chiragpr 1
  • Add domain parameter to CLossHook

    Right now it's hardcoded to use src_logits. Adding a domain parameter (set to either src or target) would allow CLossHook to be used for supervised domain adaptation as well.

    https://github.com/KevinMusgrave/pytorch-adapt/blob/3b2713c4860b325c79481f11307a193bb381d53f/src/pytorch_adapt/hooks/classification.py#L75-L88

    enhancement 
    opened by KevinMusgrave 1
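
The proposal amounts to selecting the logits and labels by domain name instead of hard-coding the source ones. A plain-Python sketch of that idea follows; classification_loss is hypothetical, and the key names src_labels and target_logits are assumptions modeled on src_logits and target_labels used elsewhere on this page, not the library's actual keys.

```python
import math
from typing import Dict, Sequence


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Numerically stable cross-entropy for a single sample."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def classification_loss(inputs: Dict[str, list], domain: str = "src") -> float:
    """Mean cross-entropy over the batch for the chosen domain ("src" or "target")."""
    if domain not in ("src", "target"):
        raise ValueError(f"domain must be 'src' or 'target', got {domain!r}")
    logits = inputs[f"{domain}_logits"]
    labels = inputs[f"{domain}_labels"]
    return sum(cross_entropy(l, y) for l, y in zip(logits, labels)) / len(labels)


batch = {
    "src_logits": [[2.0, 0.0]], "src_labels": [0],
    "target_logits": [[0.0, 0.0]], "target_labels": [1],
}
print(classification_loss(batch, domain="target"))  # log(2), about 0.6931
```

Defaulting domain to "src" keeps the existing unsupervised behavior, while domain="target" covers the supervised case.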
  • Typo in the ATDOC algorithm

    Hi, I am the first author of ATDOC. Thanks for including our method in such an impressive library.

    There is a typo in Eq. (6) of the paper (the arXiv version was updated today): the index k should be replaced with i. That is, in line 87 of pytorch_adapt/layers/neighborhood_aggregation.py, the correct code would be "logits = (logits ** p) / torch.sum(logits ** p, dim=0)".

    Best,

    Jian

    bug 
    opened by bluelg 1
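
Dividing by torch.sum(logits ** p, dim=0) normalizes over the first dimension. Assuming logits is laid out as (samples, classes), each class column is therefore normalized across the batch. The plain-Python sketch below reproduces that corrected line on nested lists; sharpen_over_samples is a hypothetical name and the layout is an assumption.

```python
from typing import List, Sequence


def sharpen_over_samples(probs: Sequence[Sequence[float]], p: float) -> List[List[float]]:
    """Raise each entry to the power p, then normalize each column (class) so it
    sums to 1 over the samples, mirroring torch.sum(logits ** p, dim=0)."""
    powered = [[x ** p for x in row] for row in probs]
    col_sums = [sum(col) for col in zip(*powered)]  # one sum per class, over samples
    return [[x / s for x, s in zip(row, col_sums)] for row in powered]


probs = [[0.8, 0.2], [0.4, 0.6]]  # rows = samples, columns = classes (assumed layout)
print(sharpen_over_samples(probs, p=2))
```

With dim=1 instead, each row would be normalized over classes, which is the behavior the typo produced.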
  • Simplify load_objects by using latest pytorch-ignite

    See: https://github.com/KevinMusgrave/pytorch-adapt/blob/97afa6d801e48b7e30854dbb11fc7ebae5abb3c3/src/pytorch_adapt/frameworks/ignite/checkpoint_utils.py#L86-L101

    enhancement 
    opened by KevinMusgrave 0
  • from pytorch_adapt error

    Hi, I installed pytorch-adapt with pip, but when I tried from pytorch_adapt.datasets import (CombinedSourceAndTargetDataset, SourceDataset, TargetDataset), the error "No module named 'pytorch_adapt'" occurred.

    My python version is 3.9.5.

    Thank you.

    opened by Jio0728 7
  • Make it clear that the downloaded OfficeHome dataset is resized

    The original OfficeHome dataset has very large images, so I downscaled them so that the shortest side is 256 pixels, but this isn't mentioned anywhere in the docs or the code.

    opened by KevinMusgrave 0
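
Resizing so that the shortest side is 256 pixels scales both sides by 256 / min(width, height), which preserves the aspect ratio. shortest_side_size is a hypothetical helper; rounding to the nearest pixel is an assumption, and the resize actually applied to OfficeHome (or torchvision's Resize(256), which matches the smaller edge to 256) may round differently.

```python
from typing import Tuple


def shortest_side_size(width: int, height: int, target: int = 256) -> Tuple[int, int]:
    """Scale (width, height) so the shorter side equals target, keeping aspect ratio."""
    scale = target / min(width, height)
    return round(width * scale), round(height * scale)


print(shortest_side_size(1024, 768))  # (341, 256)
print(shortest_side_size(768, 1024))  # (256, 341)
```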
  • Extending the Lightning class (pytorch_adapt/frameworks/lightning/lightning.py)

    Suggested Feature

    The Lightning class could be extended to include two more methods:

    1. test_step
    2. test_epoch_end

    Implementation

    test_step and test_epoch_end could operate in the same way as validation_step and validation_epoch_end, and return a test_score.

    Reasoning

    This would allow lightning users to specify a test_dataloader containing a hold-out set in their datamodule. The best saved model on validation data can then be run against the test data using pytorch lightning trainer.test function call.

    enhancement 
    opened by deepseek-eoghan 1
Releases(v0.0.82)
  • v0.0.82(Dec 1, 2022)

  • v0.0.81(Sep 20, 2022)

  • v0.0.80(Sep 2, 2022)

    Features

    • Added pretrained models for DomainNet126
    • Added transforms.classification.get_timm_transform

    Bug fixes

    • Fixed bug where map_location wasn't being used in a useful way when downloading pretrained models.
    • Fixed some formatting issues in the documentation.
  • v0.0.79(Aug 16, 2022)

    Features

    • Added APValidator
    • Added adapters.MultiLabelClassifier
    • Added hooks.MultiLabelClassifierHook
    • Added frameworks.ignite.IgniteMultiLabelClassification
    • Added models.pretrained_scores
  • v0.0.78(Aug 12, 2022)

    Features

    • Added VOCMultiLabel dataset
    • Added Clipart1kMultiLabel dataset
    • Added get_voc_multilabel dataset getter

    Breaking changes

    • Moved get_mnist_transform, get_resnet_transform, and GrayscaleToRGB to a new transforms module
  • v0.0.77(Jul 23, 2022)

    Features

    Made DomainNet126 downloadable:

    from pytorch_adapt.datasets import get_domainnet126
    datasets = get_domainnet126(["clipart"], ["real"], folder=".", download=True)
    
  • v0.0.76(Jul 4, 2022)

  • v0.0.75(Jun 28, 2022)

  • v0.0.74(May 30, 2022)

    Features

    • Pass kwargs down from pretrained model getters to load_state_dict_from_url. For example, this allows map_location to be specified:
    import torch
    from pytorch_adapt.models import office31C
    
    model = office31C(domain="dslr", pretrained=True, map_location=torch.device("cpu"))
    
  • v0.0.73(May 30, 2022)

  • v0.0.72(Apr 27, 2022)

    Added a supervised flag for dataset getters

    Setting this to True results in labeled target_train and target_val datasets.

    Example:

    from pytorch_adapt.datasets import get_mnist_mnistm
    
    datasets = get_mnist_mnistm(
        ["mnist"],
        ["mnistm"],
        folder=".",
        supervised=True,
    )
    
    # datasets["target_train"] and datasets["target_val"] are of type TargetDataset, with self.supervised = True
    

    Setting return_target_with_labels=True returns type TargetDataset instead of SourceDataset

    Example:

    from pytorch_adapt.datasets import get_mnist_mnistm
    
    datasets = get_mnist_mnistm(
        ["mnist"],
        ["mnistm"],
        folder=".",
        return_target_with_labels=True,
    )
    
    # datasets["target_train_with_labels"] and datasets["target_val_with_labels"] are of type TargetDataset
    

    Thanks to @deepseek-eoghan for the contribution.

  • v0.0.71(Apr 14, 2022)

    Improvements to TargetDataset

    • A new supervised flag, for switching between supervised and unsupervised domain adaptation.
    • Allow the wrapped dataset to return either (data, label) or just data

    See the documentation

    Code changes: #61 by @deepseek-eoghan

  • v0.0.70(Apr 9, 2022)

  • v0.0.61(Mar 2, 2022)

    Debugging messages are appended to the traceback when an exception occurs inside a hook (#24).

    For example:

    Old behavior:

    Traceback (most recent call last):
      ...
    TypeError: forward() takes 2 positional arguments but 3 were given
    

    New behavior:

    Traceback (most recent call last):
      ...
    TypeError: in GVBHook: __call__
    in ChainHook: __call__
    in OptimizerHook: __call__
    in ChainHook: __call__
    in FeaturesLogitsAndGBridge: __call__
    in GBridgeAndLogitsHook: __call__
    GBridgeAndLogitsHook: Getting src
    GBridgeAndLogitsHook: Getting output: ['src_imgs_features_logits', 'src_imgs_features_gbridge']
    GBridgeAndLogitsHook: Using model C with inputs: src_imgs_features, return_bridge
    forward() takes 2 positional arguments but 3 were given
    C.forward() signature is (input: torch.Tensor) -> torch.Tensor
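
The layered messages above can be approximated with a decorator that catches an exception, prepends "in <Name>: __call__" to its message, and re-raises the same exception object, so nested calls build up an outermost-first chain. This is a hypothetical sketch (add_hook_context is not a pytorch_adapt function), and the library's implementation may differ; exceptions whose message is not simply args[0] would not display the prefix this way.

```python
import functools


def add_hook_context(name):
    """Hypothetical decorator: prefix exception messages with the hook's name."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except Exception as e:
                # Edit the message in place and re-raise the same object,
                # so the exception type and traceback are preserved.
                first = e.args[0] if e.args else ""
                e.args = (f"in {name}: __call__\n{first}",) + e.args[1:]
                raise

        return wrapper

    return decorator


@add_hook_context("InnerHook")
def inner():
    raise TypeError("forward() takes 2 positional arguments but 3 were given")


@add_hook_context("OuterHook")
def outer():
    return inner()


try:
    outer()
except TypeError as err:
    message = str(err)
print(message)
```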
    
    
  • v0.0.60(Feb 28, 2022)

    Swapped the order of the input and output arguments of hooks.

    | Before | After |
    | - | - |
    | losses, output = hook(losses, inputs) | output, losses = hook(inputs, losses) |

    The loss input argument is now optional, which makes the top level syntax cleaner:

    # old
    hook({}, {**models, **data})
    
    # new
    hook({**models, **data})
    
Owner
Kevin Musgrave
Computer science PhD student studying computer vision and machine learning.