jaeha_lee 2025. 1. 7. 22:26

PyTorch 2.0 주요 변경 사항

1. TorchDynamo

  • PyTorch 2.0의 핵심 기능으로, 동적 그래프를 정적으로 컴파일하는 기능을 제공
  • TorchDynamo는 기존의 Python 함수 코드를 변환하여 빠르고 최적화된 그래프 실행을 지원함
  • 사용자는 torch.compile() API를 통해 모델을 쉽게 컴파일할 수 있음

2. Inductor Backend

  • Inductor는 새로운 고성능 컴파일러 백엔드로, 다양한 하드웨어(예: CPU, GPU)에 대해 최적화된 코드를 생성함.
  • 기존의 TorchScript 및 기타 백엔드보다 더 나은 성능을 제공함.

3. Dynamic Shapes 지원

  • 동적 입력 크기를 더 잘 처리할 수 있도록 개선됨.
  • TorchInductor와 함께 사용할 때도 성능 저하 없이 동적 크기 입력을 처리.

4. Better Memory Pooling

  • GPU 메모리 관리가 더욱 효율적.
  • 큰 배치 크기와 복잡한 모델에서도 더 낮은 메모리 사용량을 보여줌.

5. Enhanced Developer Tooling

  • 디버깅, 성능 분석 및 디버거 친화적인 기능이 개선됨.
  • 디버깅 도구는 컴파일된 모델과도 잘 작동하도록 설계됨.

PyTorch 1.x 코드에서 2.0으로 전환하는 방법

1. torch.compile() 사용

  • PyTorch 2.0에서는 모델의 성능을 자동으로 최적화하기 위해 torch.compile() API를 사용함
  • 기존 코드에서 torch.compile() 한 줄만 추가하면 모델 성능을 개선할 수 있음.

코드 예제:

PyTorch 1.x 코드

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
for epoch in range(10):
    inputs = torch.randn(32, 10)
    labels = torch.randn(32, 1)
    outputs = model(inputs)
    loss = torch.nn.functional.mse_loss(outputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

PyTorch 2.0 코드

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 컴파일 추가
model = torch.compile(SimpleModel())
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop (변경 없음)
for epoch in range(10):
    inputs = torch.randn(32, 10)
    labels = torch.randn(32, 1)
    outputs = model(inputs)
    loss = torch.nn.functional.mse_loss(outputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

2. 동적 입력 크기 처리

  • 기존의 TorchScript는 고정된 입력 크기를 요구했지만, PyTorch 2.0은 동적 입력 크기를 네이티브로 지원함.
  • 별도의 추가 코드 변경 없이도 다양한 입력 크기를 처리할 수 있음.

예제:

# PyTorch 2.0에서는 다음처럼 동적 입력을 처리 가능
inputs1 = torch.randn(32, 10)  # Batch size 32
inputs2 = torch.randn(64, 10)  # Batch size 64
outputs1 = model(inputs1)
outputs2 = model(inputs2)

3. TorchDynamo 디버깅

  • torch.compile()에서 문제가 발생할 경우, 디버깅 모드를 사용할 수 있음.
    model = torch.compile(model, mode="debug")

4. 호환되지 않는 기능 확인

  • 일부 PyTorch 기능이나 사용자 정의 연산은 torch.compile()에서 완벽히 지원되지 않을 수 있음
  • 이런 경우 torch._dynamo.config.verbose=True를 설정하여 문제를 추적할 수 있음

성능 비교

  • PyTorch 2.0은 대부분의 경우 모델 학습 속도를 증가시키고, 메모리 사용량을 감소시킴
  • 성능 향상을 확인하려면 기존 학습 루프에서 단순히 torch.compile()을 호출하면 됨

한가지 주의해야 할 점은, 컴파일된 모델을 저장할 때 state_dict만 저장할 수 있다. 즉,

torch.save(optimized_model.state_dict(), "foo.pt") # 가능

 

이건 되고

torch.save(optimized_model, "foo.pt") # 불가능

 

이건 안된다.

 

 

참고 자료