카테고리 없음

torch로 dataset 만들기

jaeha_lee 2024. 7. 7. 22:57

PyTorch에서 나만의 Dataset을 만들 때 반드시 작성해야 하는 함수는 다음과 같다.

  1. __init__: 데이터셋의 초기화를 담당
  2. __len__: 데이터셋의 크기(길이)를 반환
  3. __getitem__: 주어진 인덱스에 해당하는 데이터를 반환

PyTorch의 Dataset 클래스를 상속받아야 한다.
아래는 그 예시

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, labels):
        """
        데이터셋의 초기화를 수행합니다.

        Args:
            data (list or numpy array): 데이터 리스트 또는 배열
            labels (list or numpy array): 레이블 리스트 또는 배열
        """
        self.data = data
        self.labels = labels

    def __len__(self):
        """
        데이터셋의 크기를 반환합니다.

        Returns:
            int: 데이터셋의 크기
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        주어진 인덱스에 해당하는 데이터를 반환합니다.

        Args:
            idx (int): 데이터의 인덱스

        Returns:
            tuple: (데이터, 레이블)
        """
        data = self.data[idx]
        label = self.labels[idx]
        return data, label

# 데이터와 레이블 생성
data = [torch.tensor([i]) for i in range(10)]
labels = [torch.tensor([i % 2]) for i in range(10)]

# 커스텀 데이터셋 인스턴스 생성
dataset = MyDataset(data, labels)

# DataLoader 생성
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# DataLoader를 사용하여 데이터 확인
for batch in dataloader:
    print(batch)

설명

  • __init__ 함수: 데이터와 레이블을 받아서 초기화함. 여기서는 간단히 리스트로 받아 저장
  • __len__ 함수: 데이터셋의 크기를 반환함. 데이터 리스트의 길이를 반환하도록 구현함.
  • __getitem__ 함수: 주어진 인덱스에 해당하는 데이터와 레이블을 반환함. 데이터와 레이블을 튜플 형태로 반환함.
  • DataLoader를 사용하여 배치 단위로 데이터를 로드할 수 있으며, shuffle=True로 설정하여 매 에포크마다 데이터셋이 섞이도록 할 수 있음