WRENCH: Weak supeRvision bENCHmark

Overview


🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.

For more information, check out our publications (coming soon!).

🔧 What is weak supervision?

Weak supervision is a paradigm for creating training data automatically, without manual annotation.

For a brief overview, please check out this blog.

To track recent advances in weak supervision, please follow this repo.
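To make the idea concrete, here is a minimal sketch of labeling functions (LFs) on YouTube-spam-style comments: each LF is a cheap heuristic that either votes for a class or abstains, and applying all LFs yields a weak label matrix. The three heuristics, the apply_lfs helper, and the -1 abstain convention are illustrative assumptions, not Wrench's API; Wrench datasets already ship with their weak labels encoded in the dataset object.

```python
import numpy as np

ABSTAIN = -1  # assumed abstain convention (the one Snorkel uses)
HAM, SPAM = 0, 1

def lf_contains_link(text: str) -> int:
    # Hypothetical heuristic: comments containing URLs are often spam
    return SPAM if "http" in text.lower() else ABSTAIN

def lf_check_out(text: str) -> int:
    # Hypothetical heuristic: "check out ..." is a typical self-promotion phrase
    return SPAM if "check out" in text.lower() else ABSTAIN

def lf_short_comment(text: str) -> int:
    # Hypothetical heuristic: very short comments tend to be genuine reactions
    return HAM if len(text.split()) < 5 else ABSTAIN

LFS = [lf_contains_link, lf_check_out, lf_short_comment]

def apply_lfs(texts):
    """Return an (n_examples, n_lfs) matrix of weak labels."""
    return np.array([[lf(t) for lf in LFS] for t in texts], dtype=int)

texts = [
    "Check out my channel http://example.com",
    "Great song!",
    "This video reminds me of summer holidays with my family",
]
L = apply_lfs(texts)
print(L)  # one row per comment, one column per LF
```

The third comment is not covered by any LF; a label model has to decide what to do with such rows.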

🔧 Installation

[1] Install Anaconda (instructions: https://www.anaconda.com/download/)

[2] Clone the repository:

git clone https://github.com/JieyuZ2/wrench.git
cd wrench

[3] Create the virtual environment:

conda env create -f environment.yml
source activate wrench

🔧 Available Datasets

The datasets can be downloaded via this link.

classification:

Name | Task | # class | # LF | # train | # validation | # test | Data source | LF source
Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link
Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link
SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link
IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link
Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link
AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link
TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link
Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link
SemEval | relation classification | 9 | 164 | 1749 | 200 | 692 | link | link
CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link
Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link
Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link
Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link
Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link

sequence tagging:

Name | # class | # LF | # train | # validation | # test | Data source | LF source
CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link
WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link
OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link
BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link
NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link
Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link
MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link
MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link

The detailed documentation is coming soon.
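Until that documentation exists, the log in one of the issues below shows each dataset stored as train.json, valid.json and test.json under <dataset_home>/<data>/. The sketch below reads one such split into a gold label vector and a weak label matrix and reports LF coverage. The record layout used here ({id: {"label": ..., "weak_labels": [...], "data": {...}}}) and the load_split helper are assumptions for illustration; inspect the downloaded files before relying on it, and prefer wrench.dataset.load_dataset for real use.

```python
import json
import tempfile
from pathlib import Path

import numpy as np

def load_split(path):
    """Read one split, ASSUMING {id: {"label": int, "weak_labels": [int, ...], "data": {...}}}."""
    with open(path) as f:
        records = json.load(f)
    ids = sorted(records, key=int)  # keep a stable example order
    labels = np.array([records[i]["label"] for i in ids])
    weak_labels = np.array([records[i]["weak_labels"] for i in ids])
    return labels, weak_labels

# Toy split written to a temporary directory, mimicking <dataset_home>/<data>/train.json
toy = {
    "0": {"label": 1, "weak_labels": [1, -1, 1], "data": {"text": "check out my channel"}},
    "1": {"label": 0, "weak_labels": [-1, 0, -1], "data": {"text": "great song"}},
}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "train.json"
    path.write_text(json.dumps(toy))
    y, L = load_split(path)

# Fraction of examples that receive at least one non-abstain vote (-1 assumed to mean abstain)
coverage = (L != -1).any(axis=1).mean()
print(y, L.shape, coverage)
```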

🔧 Available Models

classification:

Model | Model Type | Reference | Link to Wrench
Majority Voting | Label Model | -- | link
Weighted Majority Voting | Label Model | -- | link
Dawid-Skene | Label Model | link | link
Data Programming | Label Model | link | link
MeTaL | Label Model | link | link
FlyingSquid | Label Model | link | link
Logistic Regression | End Model | -- | link
MLP | End Model | -- | link
Pre-trained Language Model | End Model | link | link
COSINE | End Model | link | link
Denoise | Joint Model | link | link
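Majority Voting and Weighted Majority Voting are the simplest label models in the table: each LF casts a (possibly weighted) vote and the normalized vote counts become a soft label. A minimal NumPy sketch of that idea, assuming -1 marks an abstain and that examples with no votes get a uniform distribution; it is not Wrench's MajorityVoting implementation, whose tie and abstain handling may differ.

```python
import numpy as np

ABSTAIN = -1  # assumed abstain convention

def majority_vote_proba(L, n_class, weights=None):
    """Soft labels from (weighted) majority voting over an LF output matrix L.

    Rows where every LF abstains get a uniform distribution over classes.
    """
    n, m = L.shape
    w = np.ones(m) if weights is None else np.asarray(weights, dtype=float)
    votes = np.zeros((n, n_class))
    for j in range(m):
        rows = np.flatnonzero(L[:, j] != ABSTAIN)
        # add LF j's weight to the class it voted for on each covered row
        votes[rows, L[rows, j]] += w[j]
    totals = votes.sum(axis=1, keepdims=True)
    uniform = np.full((n, n_class), 1.0 / n_class)
    safe_totals = np.where(totals > 0, totals, 1.0)  # avoid division by zero
    return np.where(totals > 0, votes / safe_totals, uniform)

L = np.array([[1, 1, -1],
              [0, 1, 0],
              [-1, -1, -1]])
print(majority_vote_proba(L, n_class=2))
print(majority_vote_proba(L, n_class=2, weights=[1, 3, 1]))
```

With equal weights the second row is labeled 0 by a 2-to-1 vote; giving the middle LF weight 3 flips it to class 1, which is the effect weighted voting is meant to capture.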

sequence tagging:

Model | Model Type | Reference | Link to Wrench
Hidden Markov Model | Label Model | link | link
Conditional Hidden Markov Model | Label Model | link | link
LSTM-CNNs-CRF | End Model | link | link
Pre-trained Language Model | End Model | link | link
ConNet | Joint Model | link | link

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel 
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)


#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    meter.update(**{target: metric_value})  # the keyword must match a name given to AverageMeter ('acc')

metrics = meter.get_results()
pprint.pprint(metrics)
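Conceptually, the search above scores each hyper-parameter configuration on the validation set, averaging over repeated runs to smooth out randomness, and AverageMeter performs the same kind of averaging on the test set. Below is a stdlib-only sketch of an exhaustive grid search with repeats; simple_grid_search and toy_objective are hypothetical stand-ins. Wrench's own grid_search appears to be built on Optuna (the issue log further down shows Optuna-style trial messages) and runs at most n_trials configurations, so this is only an illustration of the principle.

```python
import itertools
import statistics

def simple_grid_search(objective, search_space, n_repeats=3):
    """Try every configuration; keep the one with the best mean score over n_repeats seeds."""
    names = list(search_space)
    best_params, best_score = None, float("-inf")
    for values in itertools.product(*(search_space[name] for name in names)):
        params = dict(zip(names, values))
        score = statistics.mean(objective(seed=s, **params) for s in range(n_repeats))
        if score > best_score:  # maximize, e.g. validation accuracy
            best_params, best_score = params, score
    return best_params, best_score

def toy_objective(seed, lr, n_epochs):
    # Hypothetical "validation accuracy" peaking at lr=0.01, n_epochs=50, plus seed noise
    return -abs(lr - 0.01) - abs(n_epochs - 50) / 1000 + 0.001 * seed

best, score = simple_grid_search(
    toy_objective, {"lr": [0.001, 0.01, 0.1], "n_epochs": [5, 50, 200]}, n_repeats=3
)
print(best, score)
```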

🔧 Run a standard supervised learning pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)


#### Train an MLP classifier (falls back to CPU when no GPU is available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target, 
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)
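The patience and evaluation_step arguments control early stopping: the model is validated every few steps, and training halts once validation stops improving. The sketch below simulates that loop on a precomputed list of validation scores, assuming patience counts evaluations without improvement and that the best-scoring step is the one kept; check Wrench's trainer for its exact rule. train_with_early_stopping is a hypothetical helper.

```python
def train_with_early_stopping(val_scores, evaluation_step, patience):
    """Simulate periodic validation with early stopping.

    val_scores[k] is the validation metric after (k + 1) * evaluation_step steps.
    Returns (best_score, best_step, last_step_run).
    """
    best_score, best_step, bad_evals = float("-inf"), 0, 0
    for k, score in enumerate(val_scores):
        step = (k + 1) * evaluation_step
        if score > best_score:
            best_score, best_step, bad_evals = score, step, 0  # new best checkpoint
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return best_score, best_step, step  # early stop
    return best_score, best_step, len(val_scores) * evaluation_step

print(train_with_early_stopping([0.60, 0.70, 0.72, 0.71, 0.70, 0.69, 0.73],
                                evaluation_step=50, patience=3))
```

Here training stops at step 300, before the late improvement at step 350 is ever seen; a larger patience trades compute for robustness to noisy validation curves.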

🔧 Build a two-stage weak supervision pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)


#### Train an MLP classifier with soft labels (falls back to CPU when no GPU is available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label, 
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label, 
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
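Soft labels keep the label model's uncertainty, while hard labels commit to one class per example. The NumPy sketch below shows both ingredients: a cross-entropy against probabilistic targets (the issue traceback further down shows Wrench calling a cross_entropy_with_probs loss) and an argmax with random tie-breaking, similar in spirit to snorkel's probs_to_preds. Both functions are illustrative assumptions, not Wrench's or Snorkel's code.

```python
import numpy as np

def soft_cross_entropy(logits, target_probs):
    """Mean cross-entropy between predicted logits and (soft) target distributions."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(target_probs * log_probs).sum(axis=1).mean())

def probs_to_hard(probs, rng=None):
    """Argmax per row, breaking exact ties uniformly at random."""
    rng = np.random.default_rng(0) if rng is None else rng
    preds = []
    for row in probs:
        winners = np.flatnonzero(np.isclose(row, row.max()))
        preds.append(int(rng.choice(winners)))
    return np.array(preds)

soft_label = np.array([[0.9, 0.1], [0.4, 0.6], [0.5, 0.5]])
hard_label = probs_to_hard(soft_label)   # third row is a tie, so its class is random
one_hot = np.eye(2)[hard_label]
logits = np.array([[2.0, 0.0], [0.0, 1.0], [0.0, 0.0]])

print(hard_label)
print(soft_cross_entropy(logits, soft_label), soft_cross_entropy(logits, one_hot))
```

With soft targets, an ambiguous example such as the third row pulls the model toward an uncertain prediction instead of toward an arbitrary class.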

🔧 Procedural labeling function generator

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)


#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75, # mean accuracy
    beta=0.1, # mean propensity
    alpha_radius=0.2, # radius of accuracy
    beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load a real-world dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)


#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
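The two generators above can be understood with a small from-scratch sketch. The first function samples LFs that are conditionally independent given the true label, with per-LF accuracy and propensity (firing rate) drawn around alpha and beta as described in the comments above; the uniform sampling within each radius is an assumption. The second proposes unigram LFs that map a token to the class it co-occurs with most, keeping only tokens with enough support and accuracy gain over the class prior. Both helpers are hypothetical and unigram-only; they are not Wrench's ConditionalIndependentGenerator or NGramLFGenerator.

```python
import numpy as np

ABSTAIN = -1  # assumed abstain convention

def synth_conditional_independent(n, n_class, accs, props, seed=0):
    """Sample labels and LF outputs that are independent given the label.

    LF j fires with probability props[j]; when it fires it is correct with
    probability accs[j], otherwise it votes for a uniformly random wrong class.
    """
    rng = np.random.default_rng(seed)
    y = rng.integers(n_class, size=n)
    L = np.full((n, len(accs)), ABSTAIN)
    for j, (acc, prop) in enumerate(zip(accs, props)):
        fires = rng.random(n) < prop
        correct = rng.random(n) < acc
        wrong = (y + rng.integers(1, n_class, size=n)) % n_class  # any class except y
        L[:, j] = np.where(fires, np.where(correct, y, wrong), ABSTAIN)
    return y, L

def ngram_lfs(texts, labels, n_class, min_support=0.01, min_acc_gain=0.1):
    """Propose unigram LFs (token, class, accuracy) from labeled texts."""
    n = len(texts)
    prior = np.bincount(labels, minlength=n_class) / n
    docs_with = {}
    for i, text in enumerate(texts):
        for tok in set(text.lower().split()):
            docs_with.setdefault(tok, []).append(i)
    lfs = []
    for tok, idx in docs_with.items():
        if len(idx) / n < min_support:
            continue  # token too rare to be a useful LF
        counts = np.bincount(labels[idx], minlength=n_class)
        c = int(counts.argmax())
        acc = counts[c] / len(idx)
        if acc - prior[c] >= min_acc_gain:  # must beat always guessing class c
            lfs.append((tok, c, float(acc)))
    return lfs

# Mirrors the parameters above: alpha=0.75 +/- 0.2, beta=0.1 +/- 0.1, 10 LFs
rng = np.random.default_rng(0)
accs = rng.uniform(0.75 - 0.2, 0.75 + 0.2, size=10)
props = rng.uniform(0.1 - 0.1, 0.1 + 0.1, size=10)
y_syn, L_syn = synth_conditional_independent(10000, 2, accs, props)

texts = ["check out my channel", "check out this", "great song", "love this song"]
print(ngram_lfs(texts, np.array([1, 1, 0, 0]), 2, min_support=0.3))
```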
Comments
  • ModuleNotFoundError: No module named 'tokenizations'


    Hi, I faced some problems when trying to install the library. I tried to use pip install ws-benchmark==1.1.2rc0 as suggested in the document, the installation was successful but when I run the code I faced the error ModuleNotFoundError: No module named 'tokenizations'. Then I tried to clone the repository and create the environment using conda env create -f environment.yml, but the installation failed due to the following error FileNotFoundError: [Errno 2] No such file or directory: '/home/naiqing/miniconda3/envs/wrench/lib/python3.6/site-packages/huggingface_hub-0.0.16-py3.8.egg'. Do you have ideas on what might cause the problem and how can I fix it?

    opened by Gnaiqing 12
  • Is there a limitation of using dataset for different algs?


    Firstly, thank you for building this awesome benchmark. While I try the example with different datasets (e.g., I try astra with youtube dataset), I got some errors like this,

        loss = cross_entropy_with_probs(predict_l, batch['labels'].to(device))
    KeyError: 'labels'
    

    Can this be fixed?

    opened by mrbeann 8
  • Python Package Installation Fails


    Installing ws-benchmark python package fails due to dependency conflict (see stack trace below).

    Tested on system:

    • OS: ubuntu
    • Python: 3.8.13
    • Clean VE

    Command to replicate:

    • pip install ws-benchmark

    Stack Trace:

    ERROR: Cannot install ws-benchmark and ws-benchmark==1.1.1 because these package versions have conflicting dependencies.
    
    The conflict is caused by:
        ws-benchmark 1.1.1 depends on networkx==2.7
        snorkel 0.9.7 depends on networkx<2.4 and >=2.2
    
    To fix this you could try to:
    1. loosen the range of package versions you've specified
    2. remove package versions to allow pip attempt to solve the dependency conflict
    
    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
    
    opened by bradleyfowler123 4
  • Using Multiple GPUs


    Hi,

    Is it possible to use multiple GPUs for the experiments, or will it be in future releases? It would be a nice feature if it is not possible right now.

    Best regards.

    opened by tolgayan 4
  • Running scripts


    Hi, I am trying to run some models on the IMDB dataset.

    MLP:

    import logging
    import torch
    import numpy as np
    from wrench.dataset import load_dataset
    from wrench.labelmodel import Snorkel
    from wrench.logging import LoggingHandler
    from wrench.search import grid_search
    from wrench.endmodel import EndClassifierModel
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
        #### Search Space
        search_space = {
            'optimizer_lr': np.logspace(-5, -1, num=5, base=10),
            'optimizer_weight_decay': np.logspace(-5, -1, num=5, base=10),
        }
    
        #### Initialize the model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam'
        )
    
        #### Search best hyper-parameters using validation set in parallel
        n_trials = 20
        n_repeats = 1
        searched_paras = grid_search(
            model,
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            direction='auto',
            search_space=search_space,
            n_repeats=n_repeats,
            n_trials=n_trials,
            parallel=True,
            device=device,
        )
    
    
        #### Run end model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam',
            **searched_paras
        )
        model.fit(
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            device=device
        )
    
        logger.info(model.predict(test_data).tolist())
    
        acc = model.test(test_data, 'acc')
        logger.info(f'end model (MLP) test acc: {acc}')
    
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 902651.16it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 852639.45it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 829503.99it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:42:45<00:00,  3.24it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:50<00:00,  3.01it/s]
    [I 2021-10-23 22:24:36,807] A new study created in memory with name: no-name-9e4ad09c-ea4a-4ee8-80c2-7633429e4038
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/test.json
    2021-10-23 21:57:10 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:10:40 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:24:36 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:24:36 - label model test acc: 0.716
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    100%|██████████| 1/1 [00:37<00:00, 37.48s/it]
    [I 2021-10-23 22:25:14,563] Trial 0 finished with value: 0.5012 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:23<00:00, 23.70s/it]
    [I 2021-10-23 22:25:38,448] Trial 1 finished with value: 0.496 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.1}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:14<00:00, 14.53s/it]
    [I 2021-10-23 22:25:53,171] Trial 2 finished with value: 0.5004 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:43<00:00, 43.73s/it]
    [I 2021-10-23 22:26:37,071] Trial 3 finished with value: 0.5088 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
    [I 2021-10-23 22:26:56,161] Trial 4 finished with value: 0.488 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.81s/it]
    [I 2021-10-23 22:27:35,214] Trial 5 finished with value: 0.4948 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.15s/it]
    [I 2021-10-23 22:28:13,614] Trial 6 finished with value: 0.5024 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.01}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:15<00:00, 15.47s/it]
    [I 2021-10-23 22:28:29,335] Trial 7 finished with value: 0.4996 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:22<00:00, 22.49s/it]
    [I 2021-10-23 22:28:52,093] Trial 8 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:40<00:00, 40.25s/it]
    [I 2021-10-23 22:29:32,594] Trial 9 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.0001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:39<00:00, 39.06s/it]
    [I 2021-10-23 22:30:11,902] Trial 10 finished with value: 0.5116 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:43<00:00, 43.46s/it]
    [I 2021-10-23 22:30:55,531] Trial 11 finished with value: 0.4912 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:23<00:00, 23.41s/it]
    [I 2021-10-23 22:31:19,095] Trial 12 finished with value: 0.4956 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:22<00:00, 22.12s/it]
    [I 2021-10-23 22:31:41,374] Trial 13 finished with value: 0.492 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:15<00:00, 15.78s/it]
    [I 2021-10-23 22:31:57,283] Trial 14 finished with value: 0.5044 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.0001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:37<00:00, 37.28s/it]
    [I 2021-10-23 22:32:34,728] Trial 15 finished with value: 0.488 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:16<00:00, 16.04s/it]
    [I 2021-10-23 22:32:50,934] Trial 16 finished with value: 0.4924 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:19<00:00, 19.65s/it]
    [I 2021-10-23 22:33:10,753] Trial 17 finished with value: 0.5156 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:15<00:00, 15.41s/it]
    [I 2021-10-23 22:33:26,345] Trial 18 finished with value: 0.5068 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.001}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:16<00:00, 16.75s/it]
    [I 2021-10-23 22:33:43,222] Trial 19 finished with value: 0.498 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.01}. Best is trial 17 with value: 0.5156.
    [TRAIN]:  15%|█████▌                               | 1499/10000 [00:21<02:04, 68.19steps/s, loss=4.02, val_acc=0.5, best_val_acc=0.508, best_step=500]
    2021-10-23 22:33:43 - [END: BEST VAL / PARAMS] Best value: 0.5156, Best paras: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}
    2021-10-23 22:33:43 - 
    ==========[hyper parameters]==========
    {
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.1,
            "weight_decay": 0.1
        }
    }
    ==========[backbone config]==========
    {
        "name": "MLP",
        "paras": {
            "hidden_size": 100,
            "dropout": 0.0
        }
    }
    
    2021-10-23 22:34:09 - [INFO] early stop @ step 1500!
    2021-10-23 22:34:09 - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    2021-10-23 22:34:09 - end model (MLP) test acc: 0.5004
    

    COSINE:

    import logging
    import torch
    from wrench.dataset import load_dataset
    from wrench.logging import LoggingHandler
    from wrench.labelmodel import Snorkel
    from wrench.endmodel import Cosine
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
    
        # COSINE
        model = Cosine(
            teacher_update=100,
            margin=1.0,
            thresh=0.6,
            lr=1e-5,
            mu=1.0,
            lamda=0.05,
            backbone='BERT',
            backbone_model_name=bert_model_name,
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
        )
    
        model.fit(dataset_train=train_data,
                  dataset_valid=valid_data,
                  y_train=aggregated_hard_labels,
                  evaluation_step=10,
                  metric='acc',
                  patience=50,
                  device=device)
    
        acc = model.test(test_data, 'acc')
    
        logger.info(model.predict(test_data))
    
        logger.info(f'end model (COSINE) test acc: {acc}')
    

    For this script, I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 899119.81it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 423667.07it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 802645.44it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:47:44<00:00,  3.09it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [14:22<00:00,  2.90it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:33<00:00,  3.07it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    [TRAIN] COSINE pretrain stage:   5%|▊               | 509/10000 [21:19<6:37:40,  2.51s/steps, loss=0.605, val_acc=0.5, best_val_acc=0.5, best_step=10]
    [TRAIN] COSINE distillation stage:   0%|                                                                                 | 0/10000 [03:05<?, ?steps/s]
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:14 - loading data from ../datasets/imdb/test.json
    2021-10-23 22:02:05 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:16:34 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:30:14 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:30:14 - label model test acc: 0.716
    2021-10-23 22:30:17 - 
    ==========[hyper parameters]==========
    {
        "teacher_update": 100,
        "margin": 1.0,
        "mu": 1.0,
        "thresh": 0.6,
        "lamda": 0.05,
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.001,
            "weight_decay": 0.0
        }
    }
    ==========[backbone config]==========
    {
        "name": "BERT",
        "paras": {
            "model_name": "bert-base-cased",
            "max_tokens": 512,
            "fine_tune_layers": -1
        }
    }
    ==========[label model_config config]==========
    {
        "name": "MajorityVoting",
        "paras": {}
    }
    
    2021-10-23 22:51:52 - [INFO] early stop @ step 510!
    2021-10-23 22:55:20 - early stop because all the data are filtered!
    2021-10-23 22:56:06 - [1 1 1 ... 1 1 1]
    2021-10-23 22:56:06 - end model (COSINE) test acc: 0.5
    

    As can be seen, for both models the label model reaches a test accuracy of 0.716, while the end model only reaches 0.5004 (MLP) and 0.5 (COSINE).
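    The logged [1 1 1 ... 1 1 1] predictions suggest that the COSINE end model has collapsed to a single class. When that happens, its accuracy simply equals the share of that class in the test set, so an accuracy of exactly 0.5 is consistent with a binary test set split evenly between the two classes. A minimal plain-Python sanity check, using hypothetical helper names (prediction_histogram, constant_baseline_acc, looks_collapsed) that are not part of wrench:

```python
from collections import Counter


def prediction_histogram(preds):
    """Fraction of predictions assigned to each class."""
    counts = Counter(preds)
    total = len(preds)
    return {label: count / total for label, count in sorted(counts.items())}


def constant_baseline_acc(labels):
    """Accuracy of always predicting the most frequent gold label."""
    most_common_count = Counter(labels).most_common(1)[0][1]
    return most_common_count / len(labels)


def looks_collapsed(preds, threshold=0.99):
    """Flag a model whose predictions are (almost) all one class."""
    return max(prediction_histogram(preds).values()) >= threshold


# Toy example: a balanced binary test set and a model that always predicts 1.
gold = [0, 1] * 1250
preds = [1] * 2500
print(prediction_histogram(preds))   # {1: 1.0}
print(constant_baseline_acc(gold))   # 0.5
print(looks_collapsed(preds))        # True
```

    If an end model's test accuracy matches constant_baseline_acc and looks_collapsed is True, the model has most likely learned nothing beyond the majority class, which points at the training setup rather than at the label model.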

    Am I doing something completely wrong? Could you please tell me whether I am running the code correctly, or whether there is some issue with the hyperparameters?

    I would greatly appreciate it if you could give me some advice. It would also be great if you could include an example script for running the COSINE model.
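    The COSINE script above contains a "#### Filter out uncovered training data" comment but no filtering step, so examples on which every labeling function abstains are still passed to the end model. A minimal plain-Python sketch of such a filter, assuming (for illustration only) that each example's weak labels are a list in which -1 marks an abstention; ABSTAIN, covered_indices and filter_covered are hypothetical names, not wrench's API:

```python
ABSTAIN = -1  # assumed encoding of "no vote" from a labeling function


def covered_indices(weak_labels):
    """Indices of examples where at least one labeling function fired."""
    return [i for i, votes in enumerate(weak_labels)
            if any(v != ABSTAIN for v in votes)]


def filter_covered(examples, weak_labels, aggregated_labels):
    """Keep only covered examples together with their aggregated labels."""
    keep = covered_indices(weak_labels)
    return [examples[i] for i in keep], [aggregated_labels[i] for i in keep]


# Toy example: the second example is uncovered (all LFs abstain).
texts = ["great movie", "no opinion", "terrible plot"]
votes = [[1, -1, 1], [-1, -1, -1], [-1, 0, -1]]
labels = [1, 0, 0]
kept_texts, kept_labels = filter_covered(texts, votes, labels)
print(kept_texts)   # ['great movie', 'terrible plot']
print(kept_labels)  # [1, 0]
```

    Filtering the examples and their aggregated labels with the same index list keeps the end model's inputs and targets aligned; computing the aggregated labels only after filtering would achieve the same.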

    Thanks for the benchmark, I really appreciate it!

    opened by viheheb757 4
  • Reproducing Table 11 for classification

    Thanks for this package @JieyuZ2 -- do you happen to have an orchestration script for reproducing Table 11 (and therefore Table 3) in the Wrench paper?

    opened by pmangg 3
  • No module named 'wrench.classification.self_training'

    Hi, I am trying to run run_denoise.py, but I get the following error:

    Traceback (most recent call last):
      File "run_denoise.py", line 5, in <module>
        from wrench.classification import Denoise
      File "/gpfs/space/home/wrench/wrench/classification/__init__.py", line 4, in <module>
        from .self_training import LDSelfTrain, DDSelfTrain
    ModuleNotFoundError: No module named 'wrench.classification.self_training'
    

    Could you please add LDSelfTrain and DDSelfTrain classes?

    opened by andreaspung 3
  • Questions on the use of ground-truth labels for validation

    Thanks for putting up the benchmark! This is really great work! It seems that both the label model and the end model use the ground-truth labels for validation. For example, the base label model uses the ground-truth labels of the validation set to calculate the class balance weights: https://github.com/JieyuZ2/wrench/blob/544119e781d010797cf153307aa1090361c99522/wrench/basemodel.py#L286

    I have a few questions regarding this:

    (1) A valid baseline for the label models would be a classifier trained on the validation set, with the weak labels of the LFs as features and the ground-truth labels as the target. Given that the validation set for most datasets is actually not small, I feel such a trained model might be a pretty strong baseline compared to the other unsupervised label models.

    (2) Similar to how we combine the weak labels on the training set to get aggregated labels, we could also get aggregated labels for the validation set. These aggregated labels, instead of the ground-truth labels of the validation set, could then be used to validate the end model. Wouldn't this be a more realistic setting? Especially considering that the premise of weak supervision is to replace human labeling with programmatic labeling.
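    Point (2) can be illustrated by aggregating the validation set's weak labels with a simple majority vote, so that model selection would rely on aggregated rather than ground-truth labels. This is a plain-Python sketch, not wrench's MajorityVoting implementation: it assumes -1 marks an abstention, breaks ties toward the smallest class index, and uses a hypothetical default label for examples no LF covers.

```python
from collections import Counter

ABSTAIN = -1  # assumed encoding of an abstaining labeling function


def majority_vote(votes, default=0):
    """Aggregate one example's LF votes; ties go to the smallest class index."""
    counts = Counter(v for v in votes if v != ABSTAIN)
    if not counts:
        return default  # hypothetical fallback for uncovered examples
    best = max(counts.values())
    return min(label for label, c in counts.items() if c == best)


def accuracy(pred, gold):
    return sum(p == g for p, g in zip(pred, gold)) / len(gold)


# Toy validation set: LF votes per example plus the (normally hidden) gold labels.
valid_votes = [[1, 1, -1], [0, -1, 0], [1, 0, -1], [-1, -1, -1]]
valid_gold = [1, 0, 0, 1]

pseudo_valid = [majority_vote(v) for v in valid_votes]
print(pseudo_valid)                        # [1, 0, 0, 0]
print(accuracy(pseudo_valid, valid_gold))  # 0.75
```

    Scoring the end model against pseudo_valid instead of valid_gold would keep the whole pipeline free of ground-truth labels, at the cost of a noisier model-selection signal.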

    I appreciate any explanations. Thanks!

    opened by wurenzhi 2
  • Clarifying dataset download links

    Great work on the benchmark!

    Under the "Available Datasets" section on the main README, you provide 2 links for downloading the WRENCH datasets:

    One point of confusion is that the expanded datasets found via the Google Drive link differ from those in the direct-download zip file. For example, classification/youtube/train.json on Google Drive has 1686 instances, while the same file in the zip archive contains 1586, which matches the statistics reported in the README. Could you make it unambiguous in the documentation which download is the correct one?
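    One way to tell which copy of a split one has downloaded is to count its instances and compare the result with the README table (1586 training instances for YouTube). A minimal sketch, assuming each split file holds a single top-level JSON object or array with one entry per instance; count_instances and matches_readme are hypothetical helper names:

```python
import json
import os
import tempfile


def count_instances(path):
    """Number of top-level entries in a JSON split file (object or array)."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return len(data)


def matches_readme(path, expected):
    """Compare a split's size with the count reported in the README table."""
    n = count_instances(path)
    return n == expected, n


# Toy example: a temporary file stands in for classification/youtube/train.json.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"0": {"label": 1}, "1": {"label": 0}}, f)
    print(matches_readme(path, expected=1586))  # (False, 2)
```

    On a real download, matches_readme("classification/youtube/train.json", 1586) would return (True, 1586) for the copy that agrees with the README and (False, 1686) for the Google Drive copy described above.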

    opened by jason-fries 2
  • Fix retained probabilities

    This pull request fixes a bug that led to the wrong probabilities being stored along with the predictions of each labeling function.

    Previously, all probabilities (a 2D tensor of size batch × classes) were saved alongside the class predictions. What should have been saved instead is only the probability associated with each prediction of the model.
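    The difference described in this pull request can be sketched in plain Python: given a batch × classes probability matrix and the predicted classes, keep only each row's probability for the predicted class instead of the whole row. This is an illustration with a hypothetical helper name, not the code changed in the PR. With a PyTorch tensor, the same selection can be written as probs[torch.arange(len(preds)), preds] or probs.gather(1, preds.unsqueeze(1)).squeeze(1).

```python
def predicted_class_probs(probs, preds):
    """For each row, pick the probability of the class that was predicted.

    probs: list of rows (batch x classes); preds: one class index per row.
    """
    if len(probs) != len(preds):
        raise ValueError("probs and preds must have the same batch size")
    return [row[p] for row, p in zip(probs, preds)]


probs = [[0.1, 0.7, 0.2],
         [0.6, 0.3, 0.1]]
# argmax of each row, i.e. the model's class predictions
preds = [max(range(len(row)), key=row.__getitem__) for row in probs]
print(preds)                                # [1, 0]
print(predicted_class_probs(probs, preds))  # [0.7, 0.6]
```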

    opened by benbo 2
  • New Release

    Hi! Love the repo, super useful so far and a really easy interface to use. Thanks for putting it together!

    I was wondering whether there are plans to cut another release any time soon. We use the v1.0 tag to make sure the version is consistent across multiple builds. I noticed a few bug fixes and quality-of-life (QOL) improvements since the last release, and it would be nice to have those marked with a new tag.

    opened by rsmith49 2
  • Numba 0.43 doesn't work with newer Python versions

    The numba version 0.43 specified here doesn't work with Python 3.9. Upgrading numba to the latest version (0.54) resolves the issue. Traceback:

    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/__init__.py:3: UserWarning: The module `llvmlite.llvmpy` is deprecated and will be removed in the future.
      warnings.warn(
    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/core.py:8: UserWarning: The module `llvmlite.llvmpy.core` is deprecated and will be removed in the future. Equivalent functionality is provided by `llvmlite.ir`.
      warnings.warn(
    Traceback (most recent call last):
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/__init__.py", line 1, in <module>
        from .dawid_skene import DawidSkene
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/dawid_skene.py", line 6, in <module>
        from numba import njit, prange
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/__init__.py", line 25, in <module>
        from .decorators import autojit, cfunc, generated_jit, jit, njit, stencil
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/decorators.py", line 12, in <module>
        from .targets import registry
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/registry.py", line 5, in <module>
        from . import cpu
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/cpu.py", line 9, in <module>
        from numba import _dynfunc, config
    ImportError: /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/_dynfunc.cpython-39-x86_64-linux-gnu.so: undefined symbol: _PyObject_GC_UNTRACK
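    The workaround reported here amounts to requiring a newer numba. A minimal sketch of a startup check that fails early with a readable message instead of the ImportError above; it uses importlib.metadata (available since Python 3.8), takes 0.54 as the minimum only because that is the version this report says works, and relies on a deliberately simple, hypothetical version parser that ignores pre-release suffixes:

```python
from importlib.metadata import PackageNotFoundError, version

MIN_NUMBA = "0.54"  # version reported to work with Python 3.9 in this issue


def parse_version(text):
    """Turn '0.54.1' into (0, 54, 1); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def numba_is_recent_enough(installed, minimum=MIN_NUMBA):
    return parse_version(installed) >= parse_version(minimum)


def check_numba():
    """Raise a readable error if numba is missing or older than MIN_NUMBA."""
    try:
        installed = version("numba")
    except PackageNotFoundError:
        raise RuntimeError("numba is not installed") from None
    if not numba_is_recent_enough(installed):
        raise RuntimeError(
            f"numba {installed} found; upgrade to >= {MIN_NUMBA}, "
            "e.g. with: pip install --upgrade numba"
        )
    return installed


print(numba_is_recent_enough("0.43.1"))  # False
print(numba_is_recent_enough("0.54.0"))  # True
```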
    
    opened by susuzheng 0
  • COSINE for token classification?

    Hi,

    I would like to know whether the code for the COSINE weak supervision technique is already capable of token classification. If not, what changes would I need to make to build a weakly supervised training pipeline using some weakly labeled and unlabeled datasets?

    opened by KrishnanJothi 0
  • Balance in Dawid Skene is obsolete.

    https://github.com/JieyuZ2/wrench/blob/6d8397956533fc6c2fe50e93fcfe0a2303bdd05f/wrench/labelmodel/dawid_skene.py#L55

    I noticed that this balance variable is not used anywhere in this file. If that is intended, I think it should be removed from the input parameters.

    opened by ch-shin 1
  • Balance sum to 1

    https://github.com/JieyuZ2/wrench/blob/ab717ac26a76649c8fdb946a28dffe7e682c80ba/wrench/basemodel.py#L303

    Hi, I found a minor issue: the class prior computed by this function does not sum to 1. I hope you can fix it.
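    A class prior that sums to 1 can be obtained by normalizing the per-class counts by their total. A minimal plain-Python sketch with a hypothetical function name and an optional additive smoothing term; it is not the code in basemodel.py:

```python
from collections import Counter


def class_prior(labels, n_class, smoothing=0.0):
    """Empirical distribution over classes 0..n_class-1 that sums to 1.

    smoothing adds a pseudo-count to every class so that unseen classes get
    non-zero mass (use 0.0 for the raw frequencies).
    """
    counts = Counter(labels)
    raw = [counts.get(c, 0) + smoothing for c in range(n_class)]
    total = sum(raw)
    if total == 0:
        raise ValueError("need at least one label or a positive smoothing term")
    return [x / total for x in raw]


prior = class_prior([0, 0, 1, 2], n_class=3)
print(prior)  # [0.5, 0.25, 0.25]
```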

    opened by Gnaiqing 0
  • about COSINE endmodel

    Hi @JieyuZ2 and @yinxiangshi, I am trying to run the COSINE end model, but I have some trouble reproducing the results of the COSINE paper. Although I used the suggested hyperparameters, I still get only a marginal benefit with wrench, and I'm not sure what is wrong. Could you share the scripts you used when evaluating COSINE? Thanks.

    opened by Gnaiqing 0
  • Recommended parameters to use for each algorithm and dataset.

    I've tried several combinations of different algorithms and datasets, but I found it hard to get results similar to those in the paper. I suspect this is due to inappropriate parameter settings, so I think it would be great if this repo could provide some recommended parameters (especially for the newly added algorithms, where it's hard to judge whether they produce the right results).

    opened by mrbeann 0
Releases (v1.1)
  • v1.1 (Nov 9, 2021)

    What's new:

    • A batch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
    • A new EndClassifierModel that unifies all the classification backbones
    • Two new datasets for image classification
    • Support for PyTorch native AMP during inference in the validation step
    • Support for training on multiple GPUs via PyTorch's DistributedDataParallel and the new parallel_fit function
    • Fixed some bugs
  • v1.0 (Sep 7, 2021)

Owner
Jieyu Zhang
CS PhD