카테고리 없음
Torch + collate_fn
jaeha_lee
2024. 7. 7. 22:59
collate_fn
은 PyTorch의 DataLoader
에서 배치(batch)를 구성하는 방식을 사용자 정의할 수 있게 해주는 함수.
기본적으로 DataLoader
는 단순히 각 데이터를 배치 단위로 묶어서 반환하지만, 데이터의 형태가 복잡하거나 다양한 전처리가 필요한 경우 collate_fn
을 사용하여 맞춤형 배치 구성을 할 수 있음
- 기본적인
Dataset
구현 collate_fn
을 정의하여DataLoader
에 전달
예시
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
label = self.labels[idx]
return data, label
def my_collate_fn(batch):
"""
주어진 배치를 커스터마이즈하는 함수입니다.
Args:
batch (list of tuples): Dataset에서 반환된 배치 데이터의 리스트.
각 튜플은 (데이터, 레이블) 형식입니다.
Returns:
tuple: 배치 데이터와 배치 레이블을 각각 텐서로 변환하여 반환합니다.
"""
data = [item[0] for item in batch]
labels = [item[1] for item in batch]
# 여기서 텐서로 변환하고 추가 전처리를 수행할 수 있습니다.
data = torch.stack(data, dim=0)
labels = torch.stack(labels, dim=0)
return data, labels
# 데이터와 레이블 생성
data = [torch.tensor([i]) for i in range(10)]
labels = [torch.tensor([i % 2]) for i in range(10)]
# 커스텀 데이터셋 인스턴스 생성
dataset = MyDataset(data, labels)
# DataLoader 생성
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
# DataLoader를 사용하여 데이터 확인
for batch in dataloader:
print(batch)
설명
my_collate_fn
: 배치 데이터를 커스터마이즈하는 함수. 여기서는 단순히 데이터와 레이블을 각각 리스트에 모은 후,torch.stack
을 사용하여 배치 텐서로 변환함.batch
:Dataset
에서 반환된 여러 (데이터, 레이블) 튜플의 리스트.- 이 함수는 배치의 데이터를 묶고, 텐서로 변환하고, 필요한 전처리를 수행할 수 있음.
DataLoader
:collate_fn
인자로my_collate_fn
을 전달하여 사용자 정의 배치 구성을 적용함
이 방법을 사용하면 복잡한 데이터 구조나 특별한 전처리가 필요한 경우에도 유연하게 배치를 구성할 수 있음