Powerful unsupervised domain adaptation method for dense retrieval.

Overview

Generative Pseudo Labeling (GPL)

GPL is an unsupervised domain adaptation method for training dense retrievers. It is based on query generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs only the unlabeled target corpus and can achieve significant improvement over zero-shot models.

For more information, check out our publication GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval (https://arxiv.org/abs/2112.07577).

Installation

One can either install GPL via pip

pip install gpl

or via git clone

git clone https://github.com/UKPLab/gpl.git && cd gpl
pip install -e .

Usage

GPL accepts data in the BeIR-format. For example, we can download the FiQA dataset hosted by BeIR:

wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip
unzip fiqa.zip
head -n 2 fiqa/corpus.jsonl  # Check the data format. GPL actually needs only this `corpus.jsonl` as data input for training.
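
Each line of corpus.jsonl is one JSON object in the BeIR corpus format. A minimal illustrative line (the values are made up; the fields GPL relies on are _id, title and text):

{"_id": "0", "title": "", "text": "Some passage text from the target domain."}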

Then we can either run GPL training directly from the command line via python -m:

export dataset="fiqa"
python -m gpl.train \
    --path_to_generated_data "generated/$dataset" \
    --base_ckpt 'distilbert-base-uncased' \
    --batch_size_gpl 32 \
    --gpl_steps 140000 \
    --output_dir "output/$dataset" \
    --evaluation_data "./$dataset" \
    --evaluation_output "evaluation/$dataset" \
    --generator "BeIR/query-gen-msmarco-t5-base-v1" \
    --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
    --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
    --qgen_prefix "qgen" \
    --do_evaluation \
    # --use_amp   # Use this for efficient training if the machine supports AMP

# One can run `python -m gpl.train --help` for information on all the arguments
# To reproduce the experiments in the paper, set `base_ckpt` to "GPL/msmarco-distilbert-margin-mse" (https://huggingface.co/GPL/msmarco-distilbert-margin-mse)

or import GPL's training method in a Python script:

import gpl

dataset = 'fiqa'
gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    base_ckpt='distilbert-base-uncased',  
    # base_ckpt='GPL/msmarco-distilbert-margin-mse',  # The starting checkpoint of the experiments in the paper
    batch_size_gpl=32,
    gpl_steps=140000,
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    do_evaluation=True,
    # use_amp=True,  # Uncomment this to enable efficient float16 precision if the machine supports AMP
)

How does GPL work?

The workflow of GPL is as follows:

  1. GPL first uses a seq2seq model (we use BeIR/query-gen-msmarco-t5-base-v1 by default) to generate queries_per_passage queries for each passage in the unlabeled corpus. The query-passage pairs are viewed as positive examples for training.

    Result files (under path $path_to_generated_data): (1) ${qgen_prefix}-qrels/train.tsv, (2) ${qgen_prefix}-queries.jsonl and also (3) corpus.jsonl (copied from $evaluation_data/);

  2. Then, it runs negative mining on the target corpus, with the generated queries as input. The mined passages are viewed as negative examples for training. One can pass any dense retrievers (SBERT or Huggingface/transformers checkpoints; we use msmarco-distilbert-base-v3 + msmarco-MiniLM-L-6-v3 by default) or BM25 to the retrievers argument as the negative miner.

    Result file (under path $path_to_generated_data): hard-negatives.jsonl;

  3. Finally, it does pseudo labeling with a powerful cross-encoder (we use cross-encoder/ms-marco-MiniLM-L-6-v2 by default) on the query-passage pairs that we have so far (both positive and negative examples).

    Result file (under path $path_to_generated_data): gpl-training-data.tsv. It contains (gpl_steps * batch_size_gpl) tuples in total.

Up to now, we have the actual training data ready. One can look at sample-data/generated/fiqa for a quick example of the data format. The very last step is to apply the MarginMSE loss to teach the student retriever to mimic the margin scores CE(query, positive) - CE(query, negative) labeled by the teacher model (the cross-encoder, CE).
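
As a rough sketch (not the repository's exact implementation, which lives in gpl/toolkit/loss.py; the tensor names are hypothetical), the MarginMSE loss in PyTorch looks like this:

import torch.nn.functional as F

def margin_mse_loss(q_emb, pos_emb, neg_emb, teacher_margin):
    # Student margin: difference of dot-product scores for (query, positive)
    # and (query, negative), matching GPL's dot-product score function.
    student_margin = (q_emb * pos_emb).sum(dim=-1) - (q_emb * neg_emb).sum(dim=-1)
    # Teacher margin: CE(query, positive) - CE(query, negative), produced by
    # the cross-encoder during pseudo labeling.
    return F.mse_loss(student_margin, teacher_margin)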

Customized data

One can also provide/replace customized data for any intermediate step under the path $path_to_generated_data, keeping the same file names. GPL will then skip those intermediate steps and use the provided data.
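
Assuming the default qgen_prefix, a complete $path_to_generated_data folder (following the three steps above) looks like this:

generated/fiqa/
├── corpus.jsonl            # the target corpus
├── qgen-queries.jsonl      # generated queries (step 1)
├── qgen-qrels/
│   └── train.tsv           # positive query-passage pairs (step 1)
├── hard-negatives.jsonl    # mined negative passages (step 2)
└── gpl-training-data.tsv   # pseudo-labeled training tuples (step 3)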

Citation

If you use the code for evaluation, feel free to cite our publication GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval:

@article{wang2021gpl,
    title = "GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval",
    author = "Kexin Wang and Nandan Thakur and Nils Reimers and Iryna Gurevych",
    journal = "arXiv preprint arXiv:2112.07577",
    month = "4",
    year = "2021",
    url = "https://arxiv.org/abs/2112.07577",
}

Contact person and main contributor: Kexin Wang, [email protected]

https://www.ukp.tu-darmstadt.de/

https://www.tu-darmstadt.de/

Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.

This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.

Comments
  • Error while running the training script

    [2022-04-14 06:00:25] INFO [gpl.toolkit.pl.run:60] Begin pseudo labeling
    0%| | 0/140000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "/home/ec2-user/SageMaker/gpl/gpl/toolkit/pl.py", line 63, in run
        batch = next(hard_negative_iterator)
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
        data = self._next_data()
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 569, in _next_data
        index = self._next_index()  # may raise StopIteration
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in _next_index
        return next(self._sampler_iter)  # may raise StopIteration
    StopIteration

    opened by kingafy 3
  • Loss function

    Is it a typo to have the minus sign "-" in the MarginMSE loss function in Equation (1) of the GPL paper?

    There should be no minus sign "-", because the model should minimize MSE(delta_teacher, delta_student), not maximize it. I checked the released code of GPL; the loss function has no minus sign.

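    For reference, the corrected form of the loss (a sketch consistent with the description above, with M query-passage-pair tuples per batch) is:

    \mathcal{L} = \frac{1}{M} \sum_{i=1}^{M} \left( \hat{\delta}_i - \delta_i \right)^2, \qquad \delta_i = \mathrm{CE}(q_i, p_i^{+}) - \mathrm{CE}(q_i, p_i^{-}),

    where \delta_i is the teacher's margin and \hat{\delta}_i is the student's margin.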
    opened by dli1 2
  • GPU speedup

    I reckon this is more of a generic question for TSDAE + GPL (or any transformer used), but can you use the GPU by simply doing something like gpl.to(device)?

    opened by ahadda5 1
  • [KTLO-6] Hints for missing evaluation data

    The previous code did not give enough of a hint about missing evaluation data

    • gpl/toolkit/evaluation.py: Added checking for missing evaluation data
    • tests/unit/conftest.py: Separated sbert and sbert_path fixtures
    • tests/unit/test_eval.py: Added test
    opened by kwang2049 0
  • [KTLO-5] batch size larger than data size

    The previous code did not check whether the batch size is larger than the number of data points (i.e. the number of generated queries) in PseudoLabeler.run

    • gpl/toolkit/pl.py: Added a check at the beginning of run comparing batch size and data size
    • tests/unit/test_pl.py: Added test
    opened by kwang2049 0
  • [KTLO-4] OOM error in qgen

    The previous code did not detect OOM errors in QGen, which can be caused by a large QPP or batch size

    • modified: gpl/toolkit/qgen.py: Added try/catch
    • new file: tests/unit/test_qgen.py: Added test

    opened by kwang2049 0
  • [KTLO-3] OOM error in loadable checking

    The current version could not identify OOM errors in loadable_by_sbert_oom, since OOM is also a runtime error and this loadable check treats all runtime errors as not loadable

    • modified: gpl/toolkit/sbert.py: Raise OOM error (runtime error)
    • modified: setup.py: Added pytest
    • new file: tests/unit/conftest.py: SBERT fixture
    • new file: tests/unit/test_sbert.py: Test OOM error case
    opened by kwang2049 0
  • [KTLO-0] New EES version and black formatting

    • README.md: Hint of installing PyTorch correctly wrt. the CUDA version.
    • gpl/toolkit/beir.py: Black
    • gpl/toolkit/dataset.py: Black
    • gpl/toolkit/evaluation.py: Black
    • gpl/toolkit/log.py: Black
    • gpl/toolkit/loss.py: Black
    • gpl/toolkit/mine.py: Black
    • gpl/toolkit/mnrl.py: Black
    • gpl/toolkit/pl.py: Black
    • gpl/toolkit/qgen.py: Black
    • gpl/toolkit/reformat.py: Black
    • gpl/toolkit/rescale.py: Black
    • gpl/toolkit/resize.py: Black
    • gpl/toolkit/sbert.py: Black
    • gpl/train.py: Black
    • setup.py: Added protobuf, which is required by T5 and seems not to be installed by installing transformers alone; specified ees>=0.0.8 (which keeps the elasticsearch version the same as that required by beir)
    opened by kwang2049 0
  • Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language")?

    Hi. Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language" in your example)? In your pipeline, the first step is Query Generation: "For a given text from our domain, we first use a T5 model that generates a possible query for the given text. E.g. when your text is 'Python is a high-level general-purpose programming language', the model might generate a query like 'What is Python'. You can find various query generators on our doc2query-hub." Does that mean that texts which couldn't be converted into queries (e.g. "Investment consulting for legal entities and individuals.") cannot be used for training?

    opened by edgar2597 0
  • GPL for sentence embedding tasks?

    In the provided examples, GPL is used for semantic search tasks: given a query, relevant results should be retrieved. Is it also the recommended approach for getting meaningful embeddings / bi-encoders, or is it better to use TSDAE?

    opened by hanshupe 2
  • Guidance on gpl_steps, new_size and batch_size_gpl

    Hello,

    I am looking for some guidance on the following parameters of gpl.train().

    • gpl_steps - Do we need such a huge value of 140000 for a corpus of size 1300?
    • new_size
    • batch_size_gpl - Would it help to speed up the training if we set this to 64 or 128? How can one derive the values of these parameters from the dataset or corpus.jsonl?
    opened by MyBruso 0
  • TSDAE to GPL... Error on start

    I'm trying to go from my trained TSDAE and then apply GPL... However, keep getting errors.

    ! export dataset="hs_resume_tsdae_gpl_mini" 
    ! python -m gpl.train \
        --path_to_generated_data "generated/$dataset" \
        --base_ckpt "/Users/cfeld/Desktop/dev/trajectory/finetuning/gpl/outputs/tsdae/MiniLM-L6-H384-uncased-model" \
        --gpl_score_function "dot" \
        --batch_size_gpl 34 \
        --gpl_steps 100 \
        --queries_per_passage 1 \
        --output_dir "output/$dataset" \
        --evaluation_data "./$dataset" \
        --evaluation_output "evaluation/$dataset" \
        --generator "BeIR/query-gen-msmarco-t5-base-v1" \
        --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
        --retriever_score_functions "cos_sim" "cos_sim" \
        --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
        --use_train_qrels
    

    However, I'm getting this error:

    2022-09-12 17:37:44 - Loading faiss.
    2022-09-12 17:37:44 - Successfully loaded faiss.
    /opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py:127: RuntimeWarning: 'gpl.train' found in sys.modules after import of package 'gpl', but prior to execution of 'gpl.train'; this may result in unpredictable behaviour
      warn(RuntimeWarning(msg))
    [2022-09-12 17:37:44] INFO [gpl.train.train:79] Corpus does not exist in generated/. Now clone the one from the evaluation path ./
    [2022-09-12 17:37:44] WARNING [gpl.train.train:106] Found `qgen_prefix` is not None. By setting `use_train_qrels == True`, the `qgen_prefix` will not be used
    [2022-09-12 17:37:44] INFO [gpl.train.train:113] Loading qrels and queries from labeled data under the path of `evaluation_data`
    Traceback (most recent call last):
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 250, in <module>
        train(**vars(args))
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 114, in train
        assert 'qrels' in os.listdir(evaluation_data) and 'queries.jsonl' in os.listdir(evaluation_data)
    AssertionError
    

    Perhaps my folder structure isn't quite right? I've tried all kinds of combos... Folder structure:

    corpus.jsonl
    evaluation/
      corpus.jsonl
      hs_resume_tsdae_gpl_mini/
        corpus.jsonl
    generated/
      corpus.jsonl
      hs_resume_tsdae_gpl_mini/
        corpus.jsonl
    hs_resume_tsdae_gpl_mini/
      corpus.jsonl
    output/
      hs_resume_tsdae_gpl_mini/

    opened by christophermfeld 1
  • Evaluation data format

    Hi,

    1/ What format should the evaluation data passed via the evaluation_data argument have? Could you provide some examples of evaluation data and how it should be formatted?

    2/ How does the evaluation work on these data? What tests are run and which labels are used?

    Thanks!

    opened by Matthieu-Tinycoaching 0
  • RuntimeError: CUDA out of memory

    Hi,

    When trying to generate intermediate results with the following command:

    dataset = 'tiny'
    gpl.train(
        path_to_generated_data=f"generated/{dataset}",
        base_ckpt='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',  
        # base_ckpt='GPL/msmarco-distilbert-margin-mse',  # The starting checkpoint of the experiments in the paper
        gpl_score_function="dot",
        # Note that GPL uses MarginMSE loss, which works with dot-product
        batch_size_gpl=32,
        gpl_steps=140000,
        new_size=-1,
        # Resize the corpus to `new_size` (|corpus|) if needed. When set to None (by default), the |corpus| will be the full size. When set to -1, the |corpus| will be set automatically: If QPP * |corpus| <= 250K, |corpus| will be the full size; else QPP will be set 3 and |corpus| will be set to 250K / 3
        queries_per_passage=-1,
        # Number of Queries Per Passage (QPP) in the query generation step. When set to -1 (by default), the QPP will be chosen automatically: If QPP * |corpus| <= 250K, then QPP will be set to 250K / |corpus|; else QPP will be set 3 and |corpus| will be set to 250K / 3
        output_dir=f"output/{dataset}",
        evaluation_data=f"./{dataset}",
        evaluation_output=f"evaluation/{dataset}",
        generator="BeIR/query-gen-msmarco-t5-large-v1",
        retrievers=["msmarco-distilbert-base-tas-b", "msmarco-MiniLM-L6-cos-v5"],
        retriever_score_functions=["dot", "cos_sim"],
        # Note that these two retriever models work with cosine-similarity
        cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
        qgen_prefix="qgen",
        # This prefix will appear as part of the (folder/file) names for query-generation results: For example, we will have "qgen-qrels/" and "qgen-queries.jsonl" by default.
        do_evaluation=True,
        use_amp=True   # One can use this flag for enabling the efficient float16 precision
    )
    

    I got the following error:

    2022-08-26 11:55:08 - Loading faiss with AVX2 support.
    2022-08-26 11:55:08 - Could not load library with AVX2 support due to:
    ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
    2022-08-26 11:55:08 - Loading faiss.
    2022-08-26 11:55:08 - Successfully loaded faiss.
    [2022-08-26 11:55:10] INFO [gpl.train.train:79] Corpus does not exist in generated/tiny. Now clone the one from the evaluation path ./tiny
    [2022-08-26 11:55:10] INFO [gpl.train.train:84] Automatically set `new_size` to 83334
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 277639.61it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:10] WARNING [gpl.toolkit.resize.resize:19] `new_size` should be smaller than the corpus size
    [2022-08-26 11:55:10] INFO [gpl.toolkit.resize.resize:41] Resized the corpus in ./tiny to generated/tiny with new size 83334
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 321974.74it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:10] INFO [gpl.train.train:99] Automatically set `queries_per_passage` to 59
    [2022-08-26 11:55:10] INFO [gpl.train.train:125] No generated queries found. Now generating it
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 308459.11it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:20] INFO [beir.generation.models.auto_model.__init__:16] Use pytorch device: cuda
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:40] Starting to Generate 59 Questions Per Passage using top-p (nucleus) sampling...
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:41] Params: top_p = 0.95
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:42] Params: top_k = 25
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:43] Params: max_length = 64
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:44] Params: ques_per_passage = 59
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:45] Params: batch size = 32
    pas:   0%|                                                                                                                                                                                          | 0/133 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "/home/matthieu/Tinycoaching/GPL/v.0.1.0/gpl_query_generation.py", line 316, in <module>
        gpl.train(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/train.py", line 127, in train
        qgen(path_to_generated_data, path_to_generated_data, generator_name_or_path=generator, ques_per_passage=queries_per_passage, bsz=batch_size_generation, qgen_prefix=qgen_prefix)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/toolkit/qgen.py", line 23, in qgen
        generator.generate(corpus, output_dir=output_dir, ques_per_passage=ques_per_passage, prefix=prefix, batch_size=bsz)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/generate.py", line 54, in generate
        queries = self.model.generate(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/models/auto_model.py", line 28, in generate
        outs = self.model.generate(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1326, in generate
        return self.sample(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1944, in sample
        outputs = self(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1639, in forward
        decoder_outputs = self.decoder(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1035, in forward
        layer_outputs = layer_module(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 692, in forward
        cross_attention_outputs = self.layer[1](
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 606, in forward
        attention_output = self.EncDecAttention(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 509, in forward
        scores = torch.matmul(
    RuntimeError: CUDA out of memory. Tried to allocate 584.00 MiB (GPU 0; 23.70 GiB total capacity; 20.69 GiB already allocated; 587.94 MiB free; 20.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
    

    My corpus consists of small paragraphs of 3-4 lines and I used the use_amp option. How can I deal with this?

    opened by Matthieu-Tinycoaching 1
Releases (latest: v0.1.4)
  • v0.1.4 (Sep 29, 2022)

  • v0.1.3 (Sep 26, 2022)

    Previously, there was a conflict between easy_elasticsearch and beir on the elasticsearch dependency:

    • easy_elasticsearch requires elasticsearch==7.12.1 while
    • beir requires elasticsearch==7.9.1

    In the latest version of easy_elasticsearch, the requirements have been changed to solve this issue. Here we update gpl to install this version (easy_elasticsearch==0.0.9). Another update in easy_elasticsearch==0.0.9 is that it solves the issue of ES returning empty results (because refresh was not called after indexing).

  • v0.1.0 (Apr 19, 2022)

    Updated paper, accepted by NAACL 2022

    The GPL paper has been accepted by NAACL 2022! Major updates:

    • Improved the setting: Down-sampled the corpus if it is too large; calculate the number of generated queries according to the corpus size;
    • Added more analysis about the influence of the number of generated queries: Small corpus needs more queries;
    • Added results on the full 18 BeIR datasets: The conclusions remain the same, while we also tried training GPL on top of the powerful TAS-B model and achieved new improvements.

    Automatic hyper-parameter

    Previously, we used the whole corpus and generated 3 queries per passage, regardless of the corpus size. This results in very poor training efficiency for large corpora. In the new version, we set these two hyper-parameters automatically to meet the standard: the total number of generated queries = 250K.

    In detail, we first set queries_per_passage >= 3 and uniformly down-sample the corpus if 3 × |C| > 250K, where |C| is the corpus size; then we calculate queries_per_passage = 250K / |C|. For example, the queries_per_passage values for FiQA (original size = 57.6K) and Robust04 (original size = 528.2K) are 5 and 3, respectively, and the Robust04 corpus is down-sampled to 83.3K.
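
    A minimal sketch of this heuristic (the rounding details are assumptions; see gpl.train for the exact behaviour):

    import math

    TOTAL_QUERIES = 250_000  # target total number of generated queries
    MIN_QPP = 3              # minimum queries per passage (QPP)

    def auto_hyperparameters(corpus_size):
        # If even 3 queries per passage exceed the 250K budget, keep QPP at 3
        # and down-sample the corpus; otherwise keep the corpus and raise QPP.
        if MIN_QPP * corpus_size > TOTAL_QUERIES:
            return MIN_QPP, math.ceil(TOTAL_QUERIES / MIN_QPP)  # (QPP, new corpus size)
        return math.ceil(TOTAL_QUERIES / corpus_size), corpus_size

    # FiQA (|C| = 57.6K)      -> QPP 5, corpus kept at 57.6K
    # Robust04 (|C| = 528.2K) -> QPP 3, corpus down-sampled to ~83.3K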

    Released checkpoints (TAS-B ones)

    We now release the pre-trained GPL models via https://huggingface.co/GPL. They include the powerful GPL models trained on top of TAS-B.

  • v0.0.9 (Jan 11, 2022)

    Fixed bug of max.-sequence-length mismatch between student and teacher

    Previously, the teacher (i.e. the cross-encoder) received the concatenation of the query and document texts as input and had no limit on the max. sequence length (cf. here and here). However, the student actually had max.-sequence-length limits on the query texts and the document texts separately. This caused a mismatch between the information that could be seen by the student and the teacher models.

    In the new release, we fixed this by doing "retokenization": right before pseudo labeling, we let the tokenizer of the teacher model tokenize the query texts and the document texts separately as well, and then decode the results (token IDs) back into texts. The resulting texts meet the same max.-sequence-length requirements as the student model, which fixes this bug.
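
    A rough sketch of the idea (the function name and the max length of 350 are assumptions, not the repository's exact code):

    from transformers import AutoTokenizer

    teacher_tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

    def retokenize(texts, max_length=350):
        # Truncate with the teacher's tokenizer to the student's max length, then
        # decode back to text, so teacher and student see the same content.
        input_ids = teacher_tokenizer(texts, truncation=True, max_length=max_length)["input_ids"]
        return teacher_tokenizer.batch_decode(input_ids, skip_special_tokens=True)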

    Keep full precision of the pseudo labels

    Previously, we saved the pseudo labels from PyTorch tensors directly, which did not keep the full precision. We have now fixed this by calling labels.tolist() right before dumping the data. In practice this changes little, since the previous 6-digit precision was already high enough.

  • v0.0.8 (Dec 20, 2021)

    Independent evaluation and k_values supported

    One can now run gpl.toolkit.evaluation directly. Previously, this was only possible as part of the whole gpl.train workflow. Please check this example for more details.

    We have also added the argument k_values to gpl.toolkit.evaluation.evaluate. This is for specifying the K values in metrics such as "NDCG@K" and "Recall@K".

    Fixed bugs & use load_sbert in mnrl and evaluation

    Now almost all methods that require a separation token have an argument called sep (previously it was fixed to a blank token " "). Two exceptions are mnrl (a loss function in the SBERT repo, also the default training loss for the QGen method) and qgen, since they come from the BeIR repo (we will update the BeIR repo in the future if possible).

  • v0.0.7 (Dec 17, 2021)

    Rewrite SBERT loading

    Previously, GPL loaded starting checkpoints (--base_ckpt) by constructing the SBERT model from scratch. This could lose some information from the checkpoint (e.g. pooling and max_seq_length), and one needed to specify these carefully.

    Now we have created another method called load_sbert. It uses SentenceTransformer(base_ckpt) to load the checkpoint directly and does some checking & assertions. Loading from a Huggingface-format checkpoint (e.g. "distilbert-base-uncased") is still possible in many cases, but we recommend loading from an SBERT-format checkpoint if possible, since the starting checkpoint is then less likely to be misused.
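
    A sketch of the two loading paths (illustrative values; load_sbert wraps the first path and adds the checks mentioned above):

    from sentence_transformers import SentenceTransformer, models

    # SBERT-format checkpoint: pooling and max_seq_length are stored with the model.
    model = SentenceTransformer("GPL/msmarco-distilbert-margin-mse")

    # Huggingface-format checkpoint: pooling and max_seq_length must be chosen by
    # hand, which is what made the old from-scratch construction easy to misuse.
    word_embedding = models.Transformer("distilbert-base-uncased", max_seq_length=350)
    pooling = models.Pooling(word_embedding.get_word_embedding_dimension(), pooling_mode="mean")
    model = SentenceTransformer(modules=[word_embedding, pooling])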

    Reformatting examples

    In some cases, a Huggingface-format checkpoint cannot be loaded directly by SBERT, e.g. "facebook/dpr-question_encoder-single-nq-base". This is because:

    1. Of course, they are not in SBERT-format but in Huggingface-format;
    2. And for Huggingface-format, SBERT can only work with checkpoints whose last layer is a Transformer layer, i.e. the outputs should contain hidden states with shape (batch_size, sequence_length, hidden_dimension).

    To use these checkpoints, one needs to reformat them into SBERT-format. We have provided two examples/templates in the new toolkit source file, gpl/toolkit/reformat.py. Please refer to its readme here.

    Solved logging bug

    Previously, the logging in GPL was overridden by some other loggers and the formatting could not display as we wanted. We have now solved this by configuring the root logger. The new formatting shows many useful details:

    fmt='[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s'
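
    A minimal sketch (assumed setup) of applying this format to the root logger:

    import logging

    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )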
    
Owner
Ubiquitous Knowledge Processing Lab