collate_fn 函数用于处理数据加载器(DataLoader)取出的一批数据。在 PyTorch 中使用 DataLoader 时,通过设置 collate_fn,我们可以决定如何将多个样本整合成一个 batch。在某些情况下(例如每个样本的标注数量不同,无法直接堆叠成一个张量时),该函数需要由用户自定义以满足特定需求。
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


class MyDataset(Dataset):
    """In-memory dataset pairing images with per-image label arrays.

    Args:
        imgs: Indexable collection of images. Each image must support
            ``astype`` and a 3-axis ``transpose`` (e.g. an H x W x C
            NumPy array).
        labels: Indexable collection of labels, aligned with ``imgs`` by
            index. Labels are returned untouched, so their first dimension
            may differ from sample to sample.
    """

    def __init__(self, imgs, labels):
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        """Return the number of samples (driven by ``imgs``)."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is cast to float32 and reordered from (h, w, c) to
        (c, h, w); the label is passed through unchanged.
        """
        img = self.imgs[idx]
        out_img = img.astype(np.float32)
        # h, w, c -> c, h, w; e.g. [300, 150, 3] becomes [3, 300, 150].
        out_img = out_img.transpose(2, 0, 1)
        out_label = self.labels[idx]  # e.g. [4, 5] or [2, 5]
        return out_img, out_label


# If batch_size=3:
# batch is a list of length 3; each element is a 2-tuple (image, label):
#   batch[0]: (np[3, 300, 150], np[4, 5])
#   batch[1]: (np[3, 300, 150], np[2, 5])
#   batch[2]: (np[3, 300, 150], np[4, 5])
def my_collate_fn(batch):
    """Collate samples whose images share a shape but whose labels do not.

    Images are stacked into a single tensor; labels stay in a list because
    each image may carry a different number of annotations (bounding
    boxes), so they cannot be stacked along a new batch dimension.

    Arguments:
        batch: iterable of ``(image, label)`` pairs as produced by the
            dataset. Every image must have the same shape.

    Return:
        A tuple containing:
            1) (tensor) float batch of images stacked on dim 0,
               e.g. [3, 3, 300, 150] for three [3, 300, 150] images
            2) (list of tensors) one float tensor of annotations per
               image, in the same order as the images
    """
    imgs = [torch.FloatTensor(sample[0]) for sample in batch]
    targets = [torch.FloatTensor(sample[1]) for sample in batch]
    imgs_out = torch.stack(imgs, 0)
    return imgs_out, targets


img_data = []
label_data = []

nums = 34
H = 300
W = 150

# Build `nums` synthetic samples: a random H x W x 3 integer image plus a
# random number of 5-column label rows, so labels differ in length per image.
for _ in range(nums):
    random_img = np.random.randint(low=0, high=255, size=(H, W, 3))
    nums_target = np.random.randint(low=0, high=10)
    random_xyxy_label = np.random.random((nums_target, 5))
    img_data.append(random_img)
    label_data.append(random_xyxy_label)

dataset = MyDataset(img_data, label_data)
dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)

# `img` is one stacked tensor per batch; `label` is a list with one tensor
# per image, so each label shape is printed separately.
for cnt, (img, label) in enumerate(dataloader):
    print("==>>", cnt, ", img shape=", img.shape)
    for target in label:
        print("label shape=", target.shape)
打印如下:
==>> 0 , img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([8, 5])
label shape= torch.Size([2, 5])
label shape= torch.Size([5, 5])
==>> 1 , img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([3, 5])
label shape= torch.Size([8, 5])
label shape= torch.Size([5, 5])
==>> 2 , img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([7, 5])
label shape= torch.Size([1, 5])
label shape= torch.Size([8, 5])