본문 바로가기
AI

[파이토치] 모델 저장 및 불러오기

by Reodreamer 2022. 10. 7.
반응형

이번 포스트에서는 학습한 모델을 저장하는 방법과 저장된 모델을 불러오는 방법에 대해 다뤄보려고 한다. 일반적으로 모델을 학습하는데 많은 시간이 걸린다. 그렇기 때문에 코드 파일을 매번 실행하는 것은 굉장히 비효율적이다. 그렇기 때문에 학습이 끝난 모델을 저장하고 나중에 사용할 때 그 모델을 불러와서 사용하는 것이 시간과 자원 면에서 효율성을 증대할 수 있다.

 

모델 저장과 불러오기

먼저 실습을 위해 필요한 기본적인 패키지를 import 한다. 

import torch
import torchvision.models as models

모델을 저장하고 불러오는 방법은 크게 두 가지로 나뉜다. 

 

모델의 가중치를 저장하고 불러오는 방식

저장하기

모델이 학습한 파라미터들은 state_dict에 저장된다. 그리고 이 state_dict를 torch.save()를 이용하여 저장 파일을 생성하여 모델을 저장한다. 코드를 통해 확인해 보자. 

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

불러오기 

앞에서 저장한 model의 stat_dict를 가져오면 저장한 모델을 불러오는 것과 동일하다. 자세히 말하자면, 우선 저장할 때 사용한 모델 인스턴스와 동일한 모델 클래스를 생성한다. 그리고 해당 인스턴스에 저장한 state_dict를 load_state_dict()를 이용하여 저장된 파라미터들을 불러온다.

model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

위에서 model.eval()을 불러오는 것을 볼 수 있다. 이를 불러오는 이유는 추론 단계에서 dropout과 batch normalization을 evaluation mode로 설정하기 위함이다. 그렇게 해야 추론 단계에서 일관성 있는 결과를 생성할 수 있다. 

 

모델 전체를 저장하고 불러오는 방식

앞에서 알아본 방법은 저장된 가중치를 불러오기 전에, 모델 클래스를 먼저 불러와야 하는 불편함이 있다. 그렇기 때문에, 모델 구조 자체와 가중치를 함께 저장하고 불러오는 방식을 이용하는 것도 좋은 방법이다. 

 

저장하기

이번에는 model.state_dict가 아닌 model 자체를 저장하면 모델의 구조와 가중치가 함께 저장된다. 

torch.save(model, 'model.pth')

불러오기 

저장한 모델을 불러오는 방법은 아래와 같다. 

model = torch.load('model.pth')

모델을 저장할 때는 torch.save()를 썼는데 여기서 torch.load()로 호출하면 모델 구조와 가중치를 함께 불러올 수 있다.

 


이 포스트를 끝으로 파이 토치의 기본적인 사용법에 대해 알아보았다. 앞으로는 이를 기반으로 파이토치를 더 다양하게 활용해 보고자 한다. 

반응형

댓글