Create a Custom PyTorch Dataset with a CSV File

Ruman
5 min read · Mar 23, 2023

Introduction

PyTorch's default image dataset (torchvision's ImageFolder) has certain limitations, particularly in its file structure requirements. It expects all images to be sorted into separate folders, with each folder representing a distinct class. This works well for small datasets, but it becomes hard to manage with larger ones, such as those exceeding 100 GB, because moving the data into this folder structure can be prohibitively difficult.

If you’ve ever found yourself wishing for a more streamlined way to pass image paths and labels without having to shuffle data around, you’re in luck: the PyTorch custom dataset is here to fulfill your request!

This article will guide you through the process of using a CSV file to pass image paths and labels to your PyTorch dataset. By following the steps outlined here, you’ll be able to optimize your workflow and streamline your data handling process.

Problem with default Dataset

If you’re reading this article, you’re likely already aware of the limitations inherent in PyTorch’s default dataset. Nonetheless, it’s worth taking a closer look at the specific issues that this can cause for data handling. By doing so, we can better understand why a custom dataset is so crucial for optimizing your workflow and improving your overall data handling process.

Here’s a script that loads the images with PyTorch’s default dataset:

import torch
import torchvision
from torchvision import transforms

# IMAGE_SIZE, MEAN, STD, BATCH_SIZE and data_path are assumed to be defined elsewhere

# Define the image transformation
transformation = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

# Load all of the images with the default dataset
full_dataset = torchvision.datasets.ImageFolder(
    root=data_path,  # data_path is the path to the directory
    transform=transformation,
)

# Split the data: 70% for training, the rest for testing
train_size = int(0.7 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

# Prepare data for training with a DataLoader
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    shuffle=False,
)

# Prepare data for testing with a DataLoader
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    shuffle=False,
)

The script above uses “data_path” as the directory where all images are saved in separate folders, each folder representing a specific class. In other words, the directory contains one subfolder per class (for example, data_path/face_mask/ and data_path/no_face_mask/), with the images of that class stored inside.

With large datasets, such as 100 GB or more, moving the images into this folder structure becomes a real challenge.

Solving this with custom Dataset

Load the CSV file we’re going to use

Our CSV file has two columns:

  • file_path: the path to the image on disk
  • label: the label for that image. Our data has only two labels: face_mask and no_face_mask
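The custom dataset later turns these label strings into integer class indices. Here is a minimal sketch of that mapping, assuming the class list is built by sorting the unique label values (this mirrors the alphabetical ordering ImageFolder uses; the sample labels and the names class_to_idx and encoded are made up for illustration):

```python
# Hypothetical label column values, as they might appear in the CSV
labels_in_csv = ["no_face_mask", "face_mask", "face_mask", "no_face_mask"]

# Sort the unique labels so the class order is deterministic
class_list = sorted(set(labels_in_csv))

# Map each class name to an integer index
class_to_idx = {name: idx for idx, name in enumerate(class_list)}

# Encode every label string as its integer index
encoded = [class_to_idx[label] for label in labels_in_csv]

print(class_list)
print(class_to_idx)
print(encoded)
```

Sorting keeps the index of each class stable between runs, no matter in which order the rows appear in the CSV.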

It’s important to keep in mind that your data may come in a variety of formats and with varying numbers of columns. You may have your own CSV file with different data structures that need to be processed in a specific way.
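If your CSV uses different column names or carries extra columns, one option is to normalize it to the two columns used in this article before building the dataset. The following is only a sketch under that assumption: load_annotations, the sample column names (image, category, source) and the in-memory CSV are hypothetical and not part of the original code.

```python
import io

import pandas as pd


def load_annotations(csv_source, path_col="file_path", label_col="label"):
    """Read a CSV and return a frame with exactly `file_path` and `label` columns."""
    df = pd.read_csv(csv_source)
    missing = {path_col, label_col} - set(df.columns)
    if missing:
        raise ValueError(f"CSV is missing columns: {sorted(missing)}")
    # Keep only the two columns we need and give them the expected names
    df = df[[path_col, label_col]].rename(
        columns={path_col: "file_path", label_col: "label"}
    )
    # Drop rows without a path or label and renumber the index from 0
    return df.dropna().reset_index(drop=True)


# Hypothetical CSV with different column names and one extra column
sample_csv = io.StringIO(
    "image,category,source\n"
    "imgs/a.jpg,face_mask,cam1\n"
    "imgs/b.jpg,no_face_mask,cam2\n"
)
annotations = load_annotations(sample_csv, path_col="image", label_col="category")
print(annotations)
```

Resetting the index matters here because the dataset class looks rows up by position, so the index should run from 0 to the number of rows minus one.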

Define a custom Dataset class

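import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class CustomDataSet(Dataset):
    def __init__(self, csv_file, class_list, transform=None):
        self.df = pd.read_csv(csv_file)  # columns: file_path, label
        self.transform = transform
        self.class_list = class_list  # class names; the list position is the integer label

    def __len__(self):
        # Total number of samples
        return self.df.shape[0]

    def __getitem__(self, index):
        # Convert to RGB so that every sample has three channels
        image = Image.open(self.df.file_path[index]).convert("RGB")
        # Map the class name to its integer index
        label = self.class_list.index(self.df.label[index])

        if self.transform:
            image = self.transform(image)
        return image, label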

A custom Dataset class must have these three methods.

  • __init__ Initializes the custom Dataset class. Here I pass csv_file, class_list and transform as arguments when creating the dataset object.
  • __len__ Returns the size of the dataset, i.e. the total number of samples.
  • __getitem__ Loads and returns a single sample from the dataset for a given index. This is where you define everything that needs to happen to the sample: locating the image on disk, reading it, converting it to a tensor and applying any other transformations. It returns the image tensor and the corresponding label.

Note that how you implement the __getitem__ method is entirely up to you, including how you read the data for the given image path. In the end, it must return both the image tensor and its corresponding label. The label can be an integer class index for binary or multi-class classification, or a tensor with several values for multi-label tasks.

Use DataLoaders to load data in batches

Before loading data in batches with DataLoaders we’ll have to initialize the custom dataset object.

# Create the custom dataset object
train_data_object = CustomDataSet(csv_file_path, class_list, transform)

train_loader = torch.utils.data.DataLoader(
    train_data_object,
    batch_size=10,
    shuffle=True,
)
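The 70/30 split used earlier with the default dataset works the same way for any custom Dataset. Here is a small sketch of that idea, assuming a stand-in dataset of random tensors (RandomImages is hypothetical and only there to keep the example self-contained) and a seeded generator so that the split is reproducible:

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class RandomImages(Dataset):
    """Hypothetical stand-in dataset: random 3x8x8 tensors with alternating labels."""

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return torch.rand(3, 8, 8), index % 2


full_dataset = RandomImages(50)

# Split into roughly 70% train and 30% test, with a fixed seed
train_size = int(0.7 * len(full_dataset))
test_size = len(full_dataset) - train_size
generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = random_split(
    full_dataset, [train_size, test_size], generator=generator
)

train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# The default collate function stacks images and turns int labels into a tensor
images, labels = next(iter(train_loader))
train_batch_sizes = [batch_labels.shape[0] for _, batch_labels in train_loader]
print(images.shape, labels.dtype, train_batch_sizes)
```

The last batch can be smaller than batch_size when the number of samples is not a multiple of it; passing drop_last=True to the DataLoader skips that partial batch.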

Let’s plot a batch of images from our custom dataset

Here’s the script to select a batch of images and plot them:

train_features, train_labels = next(iter(train_loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
indx = 0
f, axarr = plt.subplots(2, 5, figsize=(12, 8))
for r in range(0, 2):
    for c in range(0, 5):
        img = train_features[indx].squeeze()
        label = train_labels[indx]
        axarr[r, c].imshow(transforms.ToPILImage()(img))
        axarr[r, c].set_title(class_labels_map.get(str(label.item())))
        indx += 1

Running the code above prints the shapes of the feature and label batches and displays a 2×5 grid of images, each titled with its class label.

Conclusion

In conclusion, we have identified the issues and constraints of using the default PyTorch Dataset and learned how to address them by creating a custom Dataset class.

If you have any questions or uncertainties, please feel free to leave a comment below.

Reference

https://discuss.pytorch.org/t/how-to-generate-a-dataloader-with-file-paths-in-pytorch/147928/2

https://stackoverflow.com/questions/62271194/pytorch-dataloader-from-csv-of-file-paths-and-labels

Adding the full code here

%matplotlib inline
import matplotlib.pyplot as plt

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Load the CSV file
csv_file_path = "/content/drive/MyDrive/experiments/pytorch_custom_dataset/data.csv"
df = pd.read_csv(csv_file_path)

# Define the image transformations
IMAGE_SIZE = 224
data_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Define and map the class labels
# It is better to sort the class label names alphabetically
class_labels = ["face_mask", "no_face_mask"]
class_labels_map = {}
for indx, label in enumerate(class_labels):
    class_labels_map[str(indx)] = label
# class_labels_map -> {'0': 'face_mask', '1': 'no_face_mask'}


# Let's define a custom Dataset class for our data
class CustomDataSet(Dataset):
    def __init__(self, csv_file, class_list, transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.class_list = class_list

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        # Convert to RGB so that every sample has three channels
        image = Image.open(self.df.file_path[index]).convert("RGB")
        label = self.class_list.index(self.df.label[index])

        if self.transform:
            image = self.transform(image)
        return image, label


# Let's create an object from our custom dataset class
train_data_object = CustomDataSet(csv_file_path, class_labels, data_transform)

# Now let's use a DataLoader to load the data in batches
train_loader = torch.utils.data.DataLoader(
    train_data_object,
    batch_size=10,
    shuffle=True,
)

# Let's plot a batch of images
train_features, train_labels = next(iter(train_loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
indx = 0
f, axarr = plt.subplots(2, 5, figsize=(12, 8))
for r in range(0, 2):
    for c in range(0, 5):
        img = train_features[indx].squeeze()
        label = train_labels[indx]
        axarr[r, c].imshow(transforms.ToPILImage()(img))
        axarr[r, c].set_title(class_labels_map.get(str(label.item())))
        indx += 1

Ruman

Senior ML Engineer | Sharing what I know, work on, learn and come across :) | Connect with me @ https://www.linkedin.com/in/rumank/