n_groups_per_batch does not work for Poverty · Issue #79 · p-lambda/wilds · GitHub


Description

@niels-leif-bracher

Hi WILDS Team,
I am currently working with the WILDS repository, specifically with the Poverty dataset. I ran the script run_expt.py in the examples folder with the argument --n_groups_per_batch=3 (and with other values). However, each batch contains samples from more than 3 different groups. Am I using this argument incorrectly? I understood --n_groups_per_batch as the number of different environments from which the samples in one batch are drawn.
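For reference, the behavior described above can be sketched as a small batch sampler. This is a hypothetical, minimal illustration (the name `group_batches` and its parameters are assumptions, not the WILDS API): each batch first picks `n_groups_per_batch` distinct groups, then fills itself with an equal number of samples from each chosen group. Under this scheme a batch of 64 with `n_groups_per_batch=2` can never contain 13 groups.

```python
import random
from collections import defaultdict


def group_batches(groups, batch_size, n_groups_per_batch, seed=0):
    """Yield batches of sample indices drawn from exactly
    n_groups_per_batch distinct groups (hypothetical sketch)."""
    if batch_size % n_groups_per_batch != 0:
        raise ValueError("batch_size must be divisible by n_groups_per_batch")
    rng = random.Random(seed)

    # Index sample positions by their group (environment) id.
    by_group = defaultdict(list)
    for idx, g in enumerate(groups):
        by_group[g].append(idx)
    group_ids = sorted(by_group)
    if n_groups_per_batch > len(group_ids):
        raise ValueError("n_groups_per_batch exceeds the number of groups")

    per_group = batch_size // n_groups_per_batch
    for _ in range(len(groups) // batch_size):
        # 1) choose which groups this batch may contain, without replacement
        chosen = rng.sample(group_ids, n_groups_per_batch)
        batch = []
        for g in chosen:
            # 2) take an equal share from each chosen group, with replacement
            #    so that a small group can still fill its share
            batch.extend(rng.choices(by_group[g], k=per_group))
        yield batch


if __name__ == "__main__":
    # Toy data: 640 samples over 13 groups, matching the count in this issue.
    groups = [i % 13 for i in range(640)]
    for batch in group_batches(groups, batch_size=64, n_groups_per_batch=2):
        print(len(batch), sorted({groups[i] for i in batch}))
```

The step that matters is 1): the groups are fixed per batch before any samples are drawn. A loader that simply shuffles the whole training split places no limit on how many groups appear in a batch. In the WILDS examples this kind of group-constrained batching appears to be tied to a separate group loader (`--train_loader group`) rather than the standard loader; whether the command below ends up using the standard loader is an assumption worth checking against examples/run_expt.py and its default configs.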

The command line reads:
python examples/run_expt.py --dataset poverty --algorithm ERM --root_dir data --n_epochs=200 --seed=0 --log_every=200 --batch_size=64 --n_groups_per_batch=2 --progress_bar True

The output when I print the n_groups variable defined in IRM.py:
n groups: 13
groups: tensor([ 3, 5, 7, 9, 10, 11, 13, 14, 16, 19, 20, 21, 22], device='cuda:0')
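The same count can be checked in plain Python, for example to compare a batch against the requested limit. `distinct_groups` and `respects_limit` are hypothetical helpers, not part of WILDS; in practice the group ids would come from the batch metadata, whereas here they are the values printed above, copied into a plain list.

```python
def distinct_groups(batch_groups):
    """Sorted distinct group ids present in one batch."""
    return sorted(set(batch_groups))


def respects_limit(batch_groups, n_groups_per_batch):
    """True if the batch draws from at most n_groups_per_batch groups."""
    return len(distinct_groups(batch_groups)) <= n_groups_per_batch


# Distinct group ids from the printed tensor above, as a plain list.
observed = [3, 5, 7, 9, 10, 11, 13, 14, 16, 19, 20, 21, 22]
print("n groups:", len(distinct_groups(observed)))  # → n groups: 13
print(respects_limit(observed, 2))  # → False
```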

In addition, is the --uniform_over_groups flag valid for Poverty, given that the samples in the training split are not uniformly distributed over the different environments?
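One common reading of "uniform over groups" is inverse-frequency reweighting: each sample gets weight 1 / (size of its group), so every group carries the same total weight and is drawn equally often regardless of how many samples it has. Under that reading, an uneven distribution across environments is not an obstacle; it is exactly what the reweighting targets. The sketch below assumes this interpretation; the helper names are hypothetical and not the WILDS implementation.

```python
import random
from collections import Counter


def uniform_over_groups_weights(groups):
    """Per-sample weights so that every group has the same total weight."""
    counts = Counter(groups)
    return [1.0 / counts[g] for g in groups]


def sample_uniform_over_groups(groups, k, seed=0):
    """Draw k sample indices with replacement, each group equally likely."""
    rng = random.Random(seed)
    weights = uniform_over_groups_weights(groups)
    return rng.choices(range(len(groups)), weights=weights, k=k)


if __name__ == "__main__":
    # Imbalanced toy split: group 0 has 90 samples, group 1 has only 10.
    groups = [0] * 90 + [1] * 10
    drawn = sample_uniform_over_groups(groups, k=10_000)
    # Each group should account for close to half of the draws.
    print(Counter(groups[i] for i in drawn))
```

Note that this only equalizes how often each group is sampled; it does not limit how many groups appear in a single batch, so it is independent of --n_groups_per_batch.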

Thanks in advance for your help.

Niels
