카테고리 없음

Torch + collate_fn

jaeha_lee 2024. 7. 7. 22:59

collate_fn은 PyTorch의 DataLoader에서 배치(batch)를 구성하는 방식을 사용자 정의할 수 있게 해주는 함수.
기본적으로 DataLoader는 단순히 각 데이터를 배치 단위로 묶어서 반환하지만, 데이터의 형태가 복잡하거나 다양한 전처리가 필요한 경우 collate_fn을 사용하여 맞춤형 배치 구성을 할 수 있음

  1. 기본적인 Dataset 구현
  2. collate_fn을 정의하여 DataLoader에 전달

예시

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        label = self.labels[idx]
        return data, label

def my_collate_fn(batch):
    """
    주어진 배치를 커스터마이즈하는 함수입니다.

    Args:
        batch (list of tuples): Dataset에서 반환된 배치 데이터의 리스트. 
                                각 튜플은 (데이터, 레이블) 형식입니다.

    Returns:
        tuple: 배치 데이터와 배치 레이블을 각각 텐서로 변환하여 반환합니다.
    """
    data = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    # 여기서 텐서로 변환하고 추가 전처리를 수행할 수 있습니다.
    data = torch.stack(data, dim=0)
    labels = torch.stack(labels, dim=0)

    return data, labels

# 데이터와 레이블 생성
data = [torch.tensor([i]) for i in range(10)]
labels = [torch.tensor([i % 2]) for i in range(10)]

# 커스텀 데이터셋 인스턴스 생성
dataset = MyDataset(data, labels)

# DataLoader 생성
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)

# DataLoader를 사용하여 데이터 확인
for batch in dataloader:
    print(batch)

설명

  • my_collate_fn: 배치 데이터를 커스터마이즈하는 함수. 여기서는 단순히 데이터와 레이블을 각각 리스트에 모은 후, torch.stack을 사용하여 배치 텐서로 변환함.
  • batch: Dataset에서 반환된 여러 (데이터, 레이블) 튜플의 리스트.
  • 이 함수는 배치의 데이터를 묶고, 텐서로 변환하고, 필요한 전처리를 수행할 수 있음.
  • DataLoader: collate_fn 인자로 my_collate_fn을 전달하여 사용자 정의 배치 구성을 적용함

이 방법을 사용하면 복잡한 데이터 구조나 특별한 전처리가 필요한 경우에도 유연하게 배치를 구성할 수 있음