카테고리 없음
torch 2.x
jaeha_lee
2025. 1. 7. 22:26
PyTorch 2.0 주요 변경 사항
1. TorchDynamo
- PyTorch 2.0의 핵심 기능으로, 동적 그래프를 정적으로 컴파일하는 기능을 제공
- TorchDynamo는 기존의 Python 함수 코드를 변환하여 빠르고 최적화된 그래프 실행을 지원함
- 사용자는
torch.compile()
API를 통해 모델을 쉽게 컴파일할 수 있음
2. Inductor Backend
- Inductor는 새로운 고성능 컴파일러 백엔드로, 다양한 하드웨어(예: CPU, GPU)에 대해 최적화된 코드를 생성함.
- 기존의 TorchScript 및 기타 백엔드보다 더 나은 성능을 제공함.
3. Dynamic Shapes 지원
- 동적 입력 크기를 더 잘 처리할 수 있도록 개선됨.
- TorchInductor와 함께 사용할 때도 성능 저하 없이 동적 크기 입력을 처리.
4. Better Memory Pooling
- GPU 메모리 관리가 더욱 효율적.
- 큰 배치 크기와 복잡한 모델에서도 더 낮은 메모리 사용량을 보여줌.
5. Enhanced Developer Tooling
- 디버깅, 성능 분석 및 디버거 친화적인 기능이 개선됨.
- 디버깅 도구는 컴파일된 모델과도 잘 작동하도록 설계됨.
PyTorch 1.x 코드에서 2.0으로 전환하는 방법
1. torch.compile()
사용
- PyTorch 2.0에서는 모델의 성능을 자동으로 최적화하기 위해
torch.compile()
API를 사용함 - 기존 코드에서
torch.compile()
한 줄만 추가하면 모델 성능을 개선할 수 있음.
코드 예제:
PyTorch 1.x 코드
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Training loop
for epoch in range(10):
inputs = torch.randn(32, 10)
labels = torch.randn(32, 1)
outputs = model(inputs)
loss = torch.nn.functional.mse_loss(outputs, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
PyTorch 2.0 코드
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 컴파일 추가
model = torch.compile(SimpleModel())
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Training loop (변경 없음)
for epoch in range(10):
inputs = torch.randn(32, 10)
labels = torch.randn(32, 1)
outputs = model(inputs)
loss = torch.nn.functional.mse_loss(outputs, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
2. 동적 입력 크기 처리
- 기존의 TorchScript는 고정된 입력 크기를 요구했지만, PyTorch 2.0은 동적 입력 크기를 네이티브로 지원함.
- 별도의 추가 코드 변경 없이도 다양한 입력 크기를 처리할 수 있음.
예제:
# PyTorch 2.0에서는 다음처럼 동적 입력을 처리 가능
inputs1 = torch.randn(32, 10) # Batch size 32
inputs2 = torch.randn(64, 10) # Batch size 64
outputs1 = model(inputs1)
outputs2 = model(inputs2)
3. TorchDynamo 디버깅
torch.compile()
에서 문제가 발생할 경우, 디버깅 모드를 사용할 수 있음.model = torch.compile(model, mode="debug")
4. 호환되지 않는 기능 확인
- 일부 PyTorch 기능이나 사용자 정의 연산은
torch.compile()
에서 완벽히 지원되지 않을 수 있음 - 이런 경우
torch._dynamo.config.verbose=True
를 설정하여 문제를 추적할 수 있음
성능 비교
- PyTorch 2.0은 대부분의 경우 모델 학습 속도를 증가시키고, 메모리 사용량을 감소시킴
- 성능 향상을 확인하려면 기존 학습 루프에서 단순히
torch.compile()
을 호출하면 됨
한가지 주의해야 할 점은, 컴파일된 모델을 저장할 때 state_dict만 저장할 수 있다. 즉,
torch.save(optimized_model.state_dict(), "foo.pt") # 가능
이건 되고
torch.save(optimized_model, "foo.pt") # 불가능
이건 안된다.
참고 자료
- PyTorch 2.0 공식 문서: https://pytorch.org/docs/stable/
- TorchDynamo와 Inductor에 대한 심층적인 정보는 PyTorch 2.0 Release Note에서 확인가능