모델을 학습하다가 보면 OOM (out of memory)이 떠서 학습에 어려움이 생기는 경우가 있다. 이때는 아래 코드를 사용해서 모델의 parameter를 삭제하고, cuda cache를 비우면 된다.
try:
# model 학습 코드
except RuntimeError as e:
print("RuntimeError in evaluate. ")
for p in model.parameters():
if p.grad() is not None:
del p.grad() # free some memory
torch.cuda.empty_cache()
continue
Python
복사