Unbalanced data loading for multi-task learning in PyTorch
A practical PyTorch guide for training multi-task models on multiple unbalanced datasets

Omri Bar


Jan 7 · 5 min read

Designed by Kjpargeter / Freepik

Working on multi-task learning (MTL) problems requires a unique training setup, mainly
in terms of data handling, model architecture, and performance evaluation metrics.

In this post, I am reviewing the data handling part. Specifically, how to train a multi-task
learning model on multiple datasets and how to handle tasks with a highly unbalanced
dataset.

I will describe my suggestion in three steps:

1. Combining two (or more) datasets into a single PyTorch Dataset. This dataset will be
the input for a PyTorch DataLoader.

2. Modifying the batch preparation process so that each mini-batch contains either samples
from a single task or a mix of samples from both tasks.

3. Handling the highly unbalanced datasets at the batch level by using a batch sampler
as part of the DataLoader.

I am only reviewing Dataset- and DataLoader-related code, ignoring other important
modules such as the model, optimizer, and metrics definitions.

. . .

For simplicity, I am using a generic two-dataset example. However, the number of
datasets and the type of data should not affect the main setup. We can even use several
instances of the same dataset, in case we have more than one set of labels for the same
set of samples: for example, a dataset of images labeled with both an object class and a
spatial location, or a face dataset labeled with both an emotion and an age per image.

A PyTorch Dataset class needs to implement the __getitem__() function, which handles
sample fetching and preparation for a given index. When using two datasets, it is then
possible to have two different methods of creating samples. Hence, we can even use a
single dataset, return samples with different labels, and change the sample processing
scheme (the output samples should have the same shape, since we stack them into a
batch tensor).
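As an illustration of that last point, here is a minimal sketch of exposing one underlying dataset as two task-specific views, one per label set. The FaceDatasetView class, its fields, and the dummy tensors are made up for this example and are not part of the post's code:

import torch


class FaceDatasetView(torch.utils.data.Dataset):
    """One underlying set of images exposed as a task-specific view:
    the same samples, but with either emotion or age labels."""
    def __init__(self, images, emotion_labels, age_labels, task):
        self.images = images
        self.labels = emotion_labels if task == "emotion" else age_labels

    def __getitem__(self, index):
        # identical processing for both tasks, so stacked batches share a shape
        return self.images[index], self.labels[index]

    def __len__(self):
        return self.images.shape[0]


# dummy data standing in for a real face dataset
images = torch.randn(20, 3, 32, 32)
emotions = torch.randint(0, 7, (20,))
ages = torch.randint(0, 100, (20,))

emotion_view = FaceDatasetView(images, emotions, ages, task="emotion")
age_view = FaceDatasetView(images, emotions, ages, task="age")

Both views return samples of the same shape, so they can later be combined exactly like the two toy datasets defined next.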

First, let’s define two datasets to work with:

import torch
from torch.utils.data.dataset import ConcatDataset


class MyFirstDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset
        self.samples = torch.cat((-torch.ones(5), torch.ones(5)))

    def __getitem__(self, index):
        # change this to your samples fetching logic
        return self.samples[index]

    def __len__(self):
        # change this to return number of samples in your dataset
        return self.samples.shape[0]


class MySecondDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset
        self.samples = torch.cat((torch.ones(50) * 5, torch.ones(5) * -5))

    def __getitem__(self, index):
        # change this to your samples fetching logic
        return self.samples[index]

    def __len__(self):
        # change this to return number of samples in your dataset
        return self.samples.shape[0]


first_dataset = MyFirstDataset()
second_dataset = MySecondDataset()
concat_dataset = ConcatDataset([first_dataset, second_dataset])

basic_dataset_example.py

We define two (binary) datasets: one with ten samples of ±1 (equally distributed), and a
second with 55 samples, 50 of the value 5 and 5 of the value -5. These datasets are only for
illustration. In a real dataset you would have both samples and labels, and you would
probably read the data from a database or parse it from data folders, but these simple
datasets are enough to understand the main concepts.
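For reference, here is a sketch of what a more realistic dataset might look like when samples are parsed from class-named folders on disk. The folder layout, the FolderDataset name, and the optional transform are assumptions for illustration only:

import os
import torch
from PIL import Image


class FolderDataset(torch.utils.data.Dataset):
    """Minimal sketch: parse an image folder where each class has its own subfolder."""
    def __init__(self, root, transform=None):
        self.transform = transform
        self.samples = []
        for label, cls in enumerate(sorted(os.listdir(root))):
            cls_dir = os.path.join(root, cls)
            for fname in os.listdir(cls_dir):
                self.samples.append((os.path.join(cls_dir, fname), label))

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.samples)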

Next, we need to define a DataLoader. We provide it with our concat_dataset and set the
loader parameters, such as the batch size and whether or not to shuffle the samples.

batch_size = 8

# basic dataloader
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

for inputs in dataloader:
    print(inputs)

basic_dataloader_example.py

The output of this part looks like:

tensor([ 5., 5., 5., 5., -5., 5., -5., 5.])


tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([-1., -5., 5., 1., 5., -1., 5., -1.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([ 5., 5., 5., 5., -5., 1., 5., 5.])
tensor([ 5., 5., 5., 1., 5., 5., 5., -1.])
tensor([ 5., 5., 5., 5., -1., 5., 1., 5.])
tensor([ 5., -5., 1., 5., 5., 5., 5., 5.])
tensor([5.])

Each batch is a tensor of 8 samples from our concat_dataset. The order is random, and
the samples are drawn from the combined pool of both datasets.

Until now, everything was relatively straightforward. The datasets are combined into a
single one, and samples are randomly picked from both of the original datasets to
construct each mini-batch. Now let's try to control and manipulate the samples in each
batch. We want to get samples from only one dataset in each mini-batch, switching
between the datasets every other batch.

import torch
from torch.utils.data.sampler import RandomSampler


class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a random batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)

    def __len__(self):
        return len(self.dataset) * self.number_of_datasets

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        datasets_length = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)
            datasets_length.append(len(cur_dataset))

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        largest_dataset_index = torch.argmax(torch.as_tensor(datasets_length)).item()
        # for this case we want to get all samples of the largest dataset, which forces us to resample the smaller ones
        epoch_samples = datasets_length[largest_dataset_index] * self.number_of_datasets

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        if i == largest_dataset_index:
                            # largest dataset iterator is done - we can break
                            samples_to_grab = len(cur_samples)  # adjusting samples_to_grab
                            # got to the end of the iterator - extend the final list and continue to the next task
                            break
                        else:
                            # restart the iterator - we want more samples until finishing with the largest dataset
                            sampler_iterators[i] = samplers_list[i].__iter__()
                            cur_batch_sampler = sampler_iterators[i]
                            cur_sample_org = cur_batch_sampler.__next__()
                            cur_sample = cur_sample_org + push_index_val[i]
                            cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)

multi_task_batch_scheduler.py

This is the definition of the BatchSchedulerSampler class, which creates a new sample
iterator: first it creates a RandomSampler for each internal dataset, and then it pulls
samples (actually sample indices) from each internal dataset iterator in turn, building a
new list of indices into the combined dataset. Using a batch size of 8 means that from
each dataset we need to fetch 8 samples per round.
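To make the index bookkeeping concrete: ConcatDataset keeps the running sizes of its parts in cumulative_sizes, and push_index_val is the offset that shifts each internal sampler's local indices into the combined index space. A quick check with our two toy datasets:

# concat_dataset holds 10 + 55 samples
print(concat_dataset.cumulative_sizes)   # [10, 65]

# offset added to indices coming from each internal dataset
push_index_val = [0] + concat_dataset.cumulative_sizes[:-1]
print(push_index_val)                    # [0, 10]

# e.g. local index 2 of the second dataset maps to global index 12
print(concat_dataset[2 + push_index_val[1]] == second_dataset[2])   # tensor(True)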

Now let’s run and print the samples using a new DataLoader, which gets our
BatchSchedulerSampler as an input sampler (shuffle can’t be set to True when working
with a sampler).

import torch
from multi_task_batch_scheduler import BatchSchedulerSampler

batch_size = 8

# dataloader with BatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BatchSchedulerSampler(dataset=concat_dataset,
                                                                       batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)

for inputs in dataloader:
    print(inputs)

batch_scheduler_dataloader_example.py

The output now looks like this:


tensor([-1., -1., 1., 1., -1., 1., 1., -1.])


tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([ 1., -1., -1., -1., 1., 1., -1., 1.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([-1., -1., 1., 1., 1., -1., 1., -1.])
tensor([ 5., 5., -5., 5., 5., -5., 5., 5.])
tensor([ 1., 1., -1., -1., 1., -1., 1., 1.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([-1., -1., -1., -1., 1., 1., 1., -1.])
tensor([ 5., -5., 5., 5., 5., 5., -5., 5.])
tensor([-1., 1., -1., 1., -1., 1., 1., -1.])
tensor([ 5., 5., 5., 5., 5., -5., 5., 5.])
tensor([ 1., -1., -1., 1., 1., 1., 1., -1.])
tensor([5., 5., 5., 5., 5., 5., 5.])

Hurray!!!
Each mini-batch now contains samples from only one dataset.
We can play with this type of scheduling in order to downsample or upsample more
important tasks.

The remaining problem in our batches now comes from the second, highly unbalanced
dataset. This is often the case in MTL: a main task accompanied by a few satellite sub-
tasks. Training the main task and the sub-tasks together may improve performance and
contribute to the generalization of the overall model. The problem is that samples of
the sub-tasks are often very sparse, with only a few positive (or negative) samples.
Let's reuse our previous logic, but also force a balanced batch with respect to the
distribution of samples in each task.

To handle the imbalance, we need to replace the random sampler in the
BatchSchedulerSampler class with an ImbalancedDatasetSampler (I am using a great
implementation from this repository). This class handles the balancing of the dataset.
We can also mix and match, using RandomSampler for some tasks and ImbalancedDatasetSampler
for others.
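As an aside, if you would rather avoid the external dependency, a similar per-task balancing effect can be approximated with PyTorch's built-in WeightedRandomSampler. This is only a sketch, and it assumes, like the toy datasets above, that the labels can be read from dataset.samples:

import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(dataset):
    # weight each sample inversely to its class frequency so that rare
    # classes are drawn roughly as often as frequent ones
    labels = torch.as_tensor([dataset.samples[i].item() for i in range(len(dataset))])
    _, inverse, counts = torch.unique(labels, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse].float()
    return WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)


balanced_sampler = make_balanced_sampler(second_dataset)

With that aside, here is the full implementation built on ImbalancedDatasetSampler: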

import torch
from torch.utils.data import RandomSampler
from sampler import ImbalancedDatasetSampler


class ExampleImbalancedDatasetSampler(ImbalancedDatasetSampler):
    """
    ImbalancedDatasetSampler is taken from https://github.com/ufoym/imbalanced-dataset-sampler
    In order to be able to show the usage of ImbalancedDatasetSampler in this example I
    override _get_label to fit my datasets
    """
    def _get_label(self, dataset, idx):
        return dataset.samples[idx].item()


class BalancedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a balanced batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)

    def __len__(self):
        return len(self.dataset) * self.number_of_datasets

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        datasets_length = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            if dataset_idx == 0:
                # the first dataset is kept at RandomSampler
                sampler = RandomSampler(cur_dataset)
            else:
                # the second unbalanced dataset is changed
                sampler = ExampleImbalancedDatasetSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)
            datasets_length.append(len(cur_dataset))

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        largest_dataset_index = torch.argmax(torch.as_tensor(datasets_length)).item()
        # for this case we want to get all samples of the largest dataset, which forces us to resample the smaller ones
        epoch_samples = datasets_length[largest_dataset_index] * self.number_of_datasets

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        if i == largest_dataset_index:
                            # largest dataset iterator is done - we can break
                            samples_to_grab = len(cur_samples)  # adjusting samples_to_grab
                            break  # got to the end of the iterator - extend the final list and continue to the next task
                        else:
                            # restart the iterator - we want more samples until finishing with the largest dataset
                            sampler_iterators[i] = samplers_list[i].__iter__()
                            cur_batch_sampler = sampler_iterators[i]
                            cur_sample_org = cur_batch_sampler.__next__()
                            cur_sample = cur_sample_org + push_index_val[i]
                            cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)

balanced_sampler.py

We first create ExampleImbalancedDatasetSampler, which inherits from
ImbalancedDatasetSampler and only modifies the _get_label function to fit our use case.

Next, we use BalancedBatchSchedulerSampler, which is similar to the previous
BatchSchedulerSampler class, only replacing RandomSampler with
ExampleImbalancedDatasetSampler for the unbalanced task.

Let’s run the new DataLoader:

import torch
from balanced_sampler import BalancedBatchSchedulerSampler

batch_size = 8

# dataloader with BalancedBatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BalancedBatchSchedulerSampler(dataset=concat_dataset,
                                                                               batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)

for inputs in dataloader:
    print(inputs)

balanced_batch_scheduler_dataloader_example.py

The output looks like:

tensor([-1., 1., 1., -1., -1., -1., 1., -1.])


tensor([ 5., 5., 5., 5., -5., -5., -5., -5.])
tensor([ 1., 1., 1., -1., 1., -1., 1., 1.])
tensor([ 5., -5., 5., -5., -5., -5., 5., 5.])
tensor([-1., -1., 1., -1., -1., -1., -1., 1.])
tensor([-5., 5., 5., 5., 5., -5., 5., -5.])
tensor([-1., -1., 1., 1., 1., 1., -1., -1.])
tensor([-5., 5., 5., 5., 5., -5., 5., 5.])
tensor([ 1., -1., 1., 1., 1., -1., 1., -1.])
tensor([ 5., 5., 5., -5., 5., -5., 5., 5.])
tensor([-1., -1., -1., -1., 1., 1., 1., 1.])
tensor([-5., 5., 5., 5., 5., 5., -5., 5.])
tensor([-1., 1., -1., 1., 1., 1., 1., 1.])
tensor([-5., -5., 5., 5., -5., -5., 5.])

The mini-batches of the unbalanced task are now much more balanced.

There is a lot of room to play with this setup even further. We can combine the tasks
within a single mini-batch in a balanced way: setting samples_to_grab to 4, half of the
batch size, gives a mixed mini-batch with 4 samples taken from each task. To weight a
more important task by a ratio of 1:3, we can set samples_to_grab=2 for the first task and
samples_to_grab=6 for the second task.
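One possible way to implement such ratios is sketched below. This RatioBatchSampler is not part of the post's repository; it is a hypothetical variant that takes a list of per-task counts instead of a single samples_to_grab:

import torch
from torch.utils.data.sampler import RandomSampler


class RatioBatchSampler(torch.utils.data.sampler.Sampler):
    """
    hypothetical variant of BatchSchedulerSampler: every mini-batch mixes the
    tasks with fixed per-task counts, e.g. counts=[2, 6] for a 1:3 ratio
    """
    def __init__(self, dataset, counts):
        self.dataset = dataset      # a ConcatDataset
        self.counts = counts        # samples per task in every mini-batch
        self.offsets = [0] + dataset.cumulative_sizes[:-1]
        # enough batches to cover the task whose quota exhausts its data last
        self.num_batches = max(len(d) // c for d, c in zip(dataset.datasets, counts))

    def __len__(self):
        return self.num_batches * sum(self.counts)

    def __iter__(self):
        iterators = [iter(RandomSampler(d)) for d in self.dataset.datasets]
        indices = []
        for _ in range(self.num_batches):
            for task, (count, offset) in enumerate(zip(self.counts, self.offsets)):
                for _ in range(count):
                    try:
                        local = next(iterators[task])
                    except StopIteration:
                        # restart the per-task iterator when it runs out
                        iterators[task] = iter(RandomSampler(self.dataset.datasets[task]))
                        local = next(iterators[task])
                    indices.append(local + offset)
        return iter(indices)


ratio_sampler = RatioBatchSampler(concat_dataset, counts=[2, 6])
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=ratio_sampler,
                                         batch_size=8,
                                         shuffle=False)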

That’s it. The full code can be downloaded from my repository.

Thanks to Amber Teng.

