카테고리 없음
torch로 dataset 만들기
jaeha_lee
2024. 7. 7. 22:57
PyTorch에서 나만의 Dataset을 만들 때 반드시 작성해야 하는 함수는 다음과 같다.
__init__
: 데이터셋의 초기화를 담당__len__
: 데이터셋의 크기(길이)를 반환__getitem__
: 주어진 인덱스에 해당하는 데이터를 반환
PyTorch의 Dataset
클래스를 상속받아야 한다.
아래는 그 예시
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, data, labels):
"""
데이터셋의 초기화를 수행합니다.
Args:
data (list or numpy array): 데이터 리스트 또는 배열
labels (list or numpy array): 레이블 리스트 또는 배열
"""
self.data = data
self.labels = labels
def __len__(self):
"""
데이터셋의 크기를 반환합니다.
Returns:
int: 데이터셋의 크기
"""
return len(self.data)
def __getitem__(self, idx):
"""
주어진 인덱스에 해당하는 데이터를 반환합니다.
Args:
idx (int): 데이터의 인덱스
Returns:
tuple: (데이터, 레이블)
"""
data = self.data[idx]
label = self.labels[idx]
return data, label
# 데이터와 레이블 생성
data = [torch.tensor([i]) for i in range(10)]
labels = [torch.tensor([i % 2]) for i in range(10)]
# 커스텀 데이터셋 인스턴스 생성
dataset = MyDataset(data, labels)
# DataLoader 생성
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# DataLoader를 사용하여 데이터 확인
for batch in dataloader:
print(batch)
설명
__init__
함수: 데이터와 레이블을 받아서 초기화함. 여기서는 간단히 리스트로 받아 저장__len__
함수: 데이터셋의 크기를 반환함. 데이터 리스트의 길이를 반환하도록 구현함.__getitem__
함수: 주어진 인덱스에 해당하는 데이터와 레이블을 반환함. 데이터와 레이블을 튜플 형태로 반환함.DataLoader
를 사용하여 배치 단위로 데이터를 로드할 수 있으며,shuffle=True
로 설정하여 매 에포크마다 데이터셋이 섞이도록 할 수 있음