본문 바로가기
AI

[PyTorch]Dataset과 DataLoader

by Reodreamer 2022. 9. 16.
반응형

이번 포스트는 PyTorch의 Dataset과 DataLoader에 대해 알아보자.

이에 앞서 이전 포스트에서 텐서의 개념에 대해 먼저 공부하고 오는 것이 도움이 될 것이라 생각한다.

 

 

[PyTorch]Tensor-part2

이번 포스트는 Tensor-part1에 이어지는 내용을 다루고자 한다. 이전 포스트를 아직 확인하지 못했다면 아래의 링크로 들어가 먼저 보고 오면 더 좋을 것 같다. [PyTorch]Tensor-part1 텐서의 특징 텐서는

dream-be.tistory.com


 

Dataset

Dataset은 torch.utils.data.Dataset을 이용하여 PyTorch에서 제공하는 pre-loaded dataset을 불러오는데 사용한다. 이와 더불어 개별 데이터를 처리하는 하위 클래스 함수들로 구성되어 있다.

TensorDataset은 클래스로 Dataset을 상속받아 학습을 위한 feature와 레이블을 함께 담는 컨테이너이다. Dataset은 feature를 가져오고 label을 지정하는 작업을 한다.

 

1) Dataset 불러오기

Dataset은 아래와 같은 파라미터를 설정하여 데이터를 불러올 수 있다.

  • root : 데이터를 저장되는 경로 
  • train : 데이터 셋이 train 혹은 test 인지 설정 
  • download : True로 설정하면 사용하고자 하는 데이터가 경로에 없으면 인터넷에서 지정한 root로 다운
  • transform : 데이터 변환을 설정한다.

 

이를 이용해 PyTorch에서 제공하는 FashionMNIST 데이터셋을 불러오는 예제를 통해 이해해보자.

이렇게 불러온 데이터들은 이미지 자체를 확인할 수 없어 원한다면 시각화를 통해 확인할 수 있다. 

 

2) 데이터 시각화하기

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

figure = plt.figure(figsize=(8,8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap='gray')
plt.show()

Out:

3) Custom Dataset 생성하기

Custom Dataset을 만들기 위해서는 Custom Dataset 클래스를 구현해야 한다. 클래스 안에는

반드시 __init__, __len__, __getitem__ 이 3가지 함수가 포함되어야 한다.

 

__init__

__init__ 함수는 Dataset 객체를 생성할 때  실행된다. 이를 통해 파라미터로 들어가는

annotation file, 경로, transform을 초기화된다.

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
        self.img_dir = img_dir 
        self.transform = transform
        self.target_transform = target_transform

__len__

__len__ 함수는 데이터 셋의 샘플 수를 반환한다. 

def __len__(self):
    return len(self.img_labels)

 

__getitem__ 

__getitem__ 함수는  주어진 인덱스와 일치하는 샘플을 불러와 반환한다. 아래의 코드를 예를 들어 설명하면, 입력되는 인덱스를 받아와 이미지를 확인하고, read_image로 이미지를 텐서로 변환한다.

이후, image_label의 라벨을 가져와 transform을 이용하여 이미지와 라벨을 딕셔너리로 저장한다. 

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx,0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

 

DataLoader

DataLoader는 PyTorch의 핵심 요소 중 하나라고 할 수 있다. DataLoader를 활용하여 Dataset을 iterable 객체로 감싸서 미니 배치 학습, 병렬처리, 데이터 shuffle과 같은 처리를 간단히 할 수 있다. 모델을 학습할 때 데이터셋을 통째로 입력하는 것이 아니라 샘플들을 미니배치로 나누어 전달하고 에폭마다 데이터를 섞는다. 이렇게 하면, 과적합을 막을수 있다. 이를 위해 데이터를 DataLoader로 불러오면 데이터 셋을 순회한다. 그래서 전체 데이터 셋에서 정해진 배치 크기 만큼 데이터를 가져와 처리하고 다음 배치를 불러오면서 학습을 한다. Shuffle 파라미터를 True로 설정하면 한 에폭이 끝나면 데이터를 섞어서 다음 에폭을 학습한다. 그렇게 함으로서 바로 전의 에폭에서 학습한 iteration의 데이터 구성이 다음 에폭의 iteration 구성과 달라 과적합을 예방하는데 도움을 준다. 

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

 


이번 포스트에선 PyTorch의 Dataset과 DataLoader에 대해 알아보았다. 데이터는 학습을 하는데 있어 매우 중요하기 때문에 이를 다루는 것 또한 매우 중요하다. 그래서 이를 염두에 두고 공부를 하면 도움이 될 것 같다.

반응형

'AI' 카테고리의 다른 글

[파이토치] Transform  (1) 2022.09.21
[논문리뷰]Attention Is All You Need  (1) 2022.09.20
[PyTorch]Tensor-part2  (0) 2022.09.15
[PyTorch]Tensor-part1  (0) 2022.09.14
[Pytorch]패키지 기본구성  (0) 2022.09.13

댓글