본문 바로가기
AI

[파이토치] Transform

by Reodreamer 2022. 9. 21.
반응형

이번 포스트는 파이토치에서 학습을 위해 데이터를 처리의 한 부분인 Transform을 하는지 공부해 보자. 


Transform은 왜 필요하고 어떻게 할까?

머신러닝과 딥러닝에서 데이터를 학습에 용이하게 활용할 수 있도록 처리를 해야 한다. Transform은 데이터를 학습에 적합하도록 만드는 작업이다. 

 

모든 torchvision dataset은 2개의 파라미터가 있다. transform은 feature들을 처리하는 데 사용하고 target_transform은 label을 처리하는 데 사용한다. 이들은 transformation 로직을 을 가지는 callable 객체를 받는다. 

 

예제를 통해 이해를 해보자. torchvision에서 제공하는 FashionMNIST 데이터 셋의 feature는 PIL 이미지 형식이고, 라벨은 interger 타입이다. 이미지 자체로는 학습을 할 수 없기 때문에 feature들은 정규화된 텐서의 형태로 만들어야 하고 라벨은 원핫-인코딩된 텐서로 변환해야 한다. 이를 위해 ToTensor 변환과 Lambda 변환을 쓸 수 있다. 

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data", 
    train=True, 
    download=True, 
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y),value=1))
)

 

ToTensor Transforms 

ToTensor 변환은 PIL 이미지나 Numpy array를 FloatTensor로 변환할 때 사용하는데 이는 이미지의 픽셀 값을 [0,1] 범위로 스케일링한다. 

 

Lambda Transforms

람다 변환은 사용자가 정의한 람다 함수를 적용한다. 위에서는 라벨 값은 원핫-인코딩한는 람다 함수를 사용했다. 데이터 셋의 라벨의 수대로 사이즈가 10인 영텐서를 생성하고 scatter_를 이용하여 기존 라벨 값을 기반으로 1을 값으로 넣었다. 


이번 포스트에선 파이토치를 이용해 학습을 위해 데이터를 변환하는것에 대해 공부해봤다.

 

반응형

'AI' 카테고리의 다른 글

[파이토치]신경망 구성  (0) 2022.09.29
[논문리뷰]Conditional Generative Adversarial Nets(CGAN)  (1) 2022.09.23
[논문리뷰]Attention Is All You Need  (1) 2022.09.20
[PyTorch]Dataset과 DataLoader  (1) 2022.09.16
[PyTorch]Tensor-part2  (0) 2022.09.15

댓글