프로그래밍

[Pytorch] 모델 저장 방법, 그리고 전체 저장과 state_dict 저장의 차이

DongJin Jeong 2021. 1. 3. 20:43

모델 저장의 방법

Pytorch에서는 학습된 모델을 저장할 때 torch.save(object, file) 함수를 사용하게 된다.

# object : 저장할 모델 객체, file : 저장할 위치 및 파일 이름

# Case 1
torch.save(model, 'model.pt')
# Case 2
torch.save(model.state_dict(), 'model.pt')

Pytorch 모델 파일은 확장자가 pt이다.

모델을 저장할 때는 두 가지 방법 중 한 방법을 선택할 수 있는데, 모델 전체를 저장하는 방법모델의 state_dict만 저장하는 것이다.

모델 전체 저장

모델 전체를 저장한다는 것의 의미는 모델 파라미터 뿐만 아니라, 옵티마이저(Optimizer), 에포크, 스코어 등 모든 상태를 저장한다는 것이다. 만약 나중에 이어서 학습을 한다던지, 코드에 접근할 권한이 없는 사용자가 모델을 사용할 수 있도록 허락해주고 싶을 때 등의 경우에 사용하는 것이 바람직하다. 모델 전체를 저장하는 만큼, 상대적으로 더 큰 용량을 가지게 된다.

# Save model
torch.save(model, 'model.pt')
# Load model
model = torch.load('model.pt')

모델의 state_dict만 저장

Pytorch에서 모델의 state_dict은 학습가능한 매개변수가 담겨있는 딕셔너리(Dictionary)이다. 가중치와 편향이 이에 해당한다. 그러나 매개변수 이외에는 정보가 담겨있지 않기 때문에, 코드 상으로 모델이 구현되어 있는 경우에만 로드하는 방법을 통해 사용할 수 있다. state_dict만 저장하면 파일의 용량이 가벼워진다는 장점이 있다.

# Save model
torch.save(model.state_dict(), 'model.pt')
# Load model
model.load_state_dict(torch.load('model.pt'))