2025. 1. 7. 22:54ㆍ카테고리 없음
torch.compile()
을 사용할 때 특정 연산(예: 사용자 정의 CUDA 커널)이 지원되지 않는 경우,
PyTorch의 디버깅 도구를 활용하여 문제를 추적하고 해결할 수 있음
1. torch._dynamo.config.verbose=True
로 디버깅 활성화
torch.compile()
에서 문제가 발생하면, torch._dynamo.config.verbose=True
를 설정하여 PyTorch의 디버깅 메시지를 활성화
이렇게 하면 어떤 연산이나 코드가 문제가 되는지 확인할 수 있음
import torch
import torch._dynamo
torch._dynamo.config.verbose = True
이 설정을 활성화한 상태로 torch.compile()
을 사용하면, 컴파일 과정에서 발생한 오류와 문제의 위치를 상세히 출력 가능
2. 문제 위치 추적
출력된 로그에서 다음을 확인
- 오류 메시지: 예를 들어, "Unsupported Operation" 또는 "Fallback to Eager Execution"과 같은 메시지가 나타날 수 있음
- 문제가 된 연산: 문제가 되는 특정 연산(예: 사용자 정의 연산 또는 특정 모듈)이 로그에 표시됨.
# 문제가 되는 부분을 일반(Eager) 모드로 실행
with torch._dynamo.disable():
problematic_output = problematic_function(input)
3. 문제를 해결하는 방법
(1) 문제가 된 연산을 제외하고 실행
torch.compile()
는 부분적으로만 모델을 컴파일할 수도 있음
문제가 되는 코드 블록을 torch.compile()
의 컴파일에서 제외하고, 나머지 코드만 최적화할 수 있음
import torch
import torch.nn as nn
import torch.optim as optim
class CustomModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 1)
def forward(self, x):
x = self.fc1(x)
# 문제가 되는 연산은 여기서 실행
x = x.relu() # 예시: 문제 발생 지점
return self.fc2(x)
# 모델 전체를 컴파일하지 않고, 문제가 되는 부분만 제외
model = CustomModel()
compiled_fc1 = torch.compile(model.fc1) # 일부 모듈만 컴파일
compiled_fc2 = torch.compile(model.fc2)
# Forward pass에서 문제를 우회
def forward_with_custom_logic(x):
x = compiled_fc1(x)
x = x.relu() # 문제 연산을 일반 모드로 실행
return compiled_fc2(x)
# Training loop
inputs = torch.randn(32, 10)
outputs = forward_with_custom_logic(inputs)
(2) Fallback to Eager Mode
PyTorch는 특정 연산에서 문제가 발생하면 자동으로 Eager 모드로 돌아가게 설정할 수 있음. 이 방식은 문제를 우회하면서도 성능을 일부 유지
model = torch.compile(model, fallback=True)
이 설정을 통해 문제가 되는 연산이 발견되면, 해당 부분은 Eager 모드로 실행되고 나머지 연산은 최적화된 그래프에서 처리됨
(3) 커스텀 연산을 TorchDynamo에서 지원하도록 등록
사용자 정의 연산(CUDA 커널 등)이 문제의 원인인 경우, 해당 연산을 TorchDynamo에 지원하도록 등록할 수 있음. 아래는 사용자 정의 연산을 TorchDynamo에서 처리하도록 설정하는 예제
import torch
import torch._dynamo
# 사용자 정의 연산 정의
def custom_op(x):
return x ** 2 + 1
# TorchDynamo에 연산 추가
@torch._dynamo.register
def handle_custom_op(x):
return custom_op(x)
# 연산이 포함된 모델
class CustomOpModel(nn.Module):
def forward(self, x):
return handle_custom_op(x)
model = CustomOpModel()
model = torch.compile(model)
# 실행
inputs = torch.randn(32, 10)
outputs = model(inputs)
4. 디버깅 모드
PyTorch의 디버깅 모드인 mode="debug"
를 사용하여, 실행 결과를 확인하고 문제를 좁혀 나갈 수 있음.
model = torch.compile(model, mode="debug")
mode="debug"
를 설정하면, 실행이 실패해도 더 많은 디버깅 정보를 제공하며, 그래프의 최적화 결과를 상세히 출력
5. 유의 사항
torch.compile()은 model.train()과 model.eval() 모두를 지원함 -> 하지만 torch.no_grad()와 결합할 때 성능이 완벽히 최적화 되지 않을 수 있음
model.eval()
with torch.no_grad():
compiled_model = torch.compile(model)
outputs = compiled_model(inputs)