Pytorch DataLoader의 OOM Kill 이슈
Pytorch를 그렇게나 많이 썼는데 이런 이슈가 있는지는 처음알았다.
대용량 데이터셋으로 대량의 worker를 사용해서 오래학습할 때 문제가 되고 보통은 큰 문제가 되지 않는 수준이라 그런 가보다.
학습을 하는데 계속 memory 사용량이 조금씩 오르면서 결국 oom kill을 당하는 문제가 있어서 골머리를 썩혔는데
그래서 memory leak이 있는지 dataset에서 getitem 호출할 때마다 memory logging 해보고 했는데 memory 사용량이 딱히 일정하게 늘어나지 않는 것이다. 변동이 조금씩 있긴 하지만?
근데 system available memory는 계속 줄어들고 있었고 확실히 뭔가가 memory 사용량이 계속 늘고 있다는 뜻이었다.
그래서 나는 side process나 logger 같은 애들이 데이터를 쌓아둬서 그런가 의심했는데...
막 모든 process의 memory 사용량을 logging해보고 했는데 결과적으로 그게 아니었다.
일단은 data loader의 memory 사용량을 잘못 측정했다.
나는 rss를 쟀는데 이건 uss에 shared memory가 더해진 값이라 multi process의 memory를 측정할 때 shared memory가 뻥튀기 되는 효과가 있다.
구글링을 해보니.... 저런 상당히 쉬운 조건에서 나타나는 memory leak이 있었다.
https://github.com/pytorch/pytorch/issues/13246
https://ppwwyyxx.com/blog/2022/Demystify-RAM-Usage-in-Multiprocess-DataLoader/
요지는 다음과 같다.
python multiprocessing을 fork로 사용할 때 저런 init 안에 두는 데이터같은 공통 object는 process마다 복사해가는 게 아니라 (그랬으면 memory가 터지겠죠?) 하나로 shared memory에 올려두고 (read-only만 가능) 그 memory에 대한 page만 각 process에서 keep해서 접근해 사용한다.
문제는 이제 write operation이 발생하면 그 page를 copy해와야 한다. 근데 이제 파이썬은 reference count로 gc가 돌기 때문에 특정 객체에 접근해서 새로 저장하는 것 만으로도 write operation이 발생하게 된다.
그래서 self.data = [...] 식으로 저장해놓고 data = self.data[i] 처럼 불러오게 되면 이 모든 객체에 대한 memory page를 copy해와야하는 것이다. shuffle을 True로 주면 문제가 심해진다고 하는데 아마도 locality가 있을테니 인접한 index는 같은 page에 들어있을 확률이 높아서 그런 듯 하다.
이거 때문에 추석 연휴 때 계속 다시 돌리고 고생했는데... 이런 문제가 있었다니 충격적이다.
해결책은 간단?할 수도 있고 아닐 수도 있는데 간단한 방법은 객체의 숫자 자체를 줄이는 것이다. 결국은 reference count에 대한 write operation이 문제가 되기 때문에 절대적인 객체의 숫자가 줄어들면 된다. 그래서 가장 간단한 방법은 np.array로 감싸는 것이다. 그러면 결국 getitem에서 액세스하는 객체는 항상 np.array 하나 뿐이고 그 객체의 특정 memory에 접근해 데이터를 읽어오게 되니 말이다.
import multiprocessing
from collections import defaultdict
import numpy as np
import psutil
import torch
from torch.utils.data import DataLoader, Dataset
def current_memory_usage() -> float:
"""Get the current memory usage in MB.
Returns:
float: The current memory usage in MB.
"""
res = defaultdict(int)
for mmap in psutil.Process().memory_maps():
res["rss"] += mmap.rss
res["pss"] += mmap.pss
res["uss"] += mmap.private_clean + mmap.private_dirty
res["shared"] += mmap.shared_clean + mmap.shared_dirty
if mmap.path.startswith("/"): # looks like a file path
res["shared_file"] += mmap.shared_clean + mmap.shared_dirty
return res
class DataIter(Dataset):
def __init__(self):
self.data = [x for x in range(24000000)]
self.data_np = np.array(self.data)
self.data_dict = {x: x for x in range(24000000)}
self.data_mp_dict = multiprocessing.Manager().dict(self.data_dict)
self.data_np_object = np.array(self.data, dtype=object)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
res = current_memory_usage()
print(
f"PID: {psutil.Process().pid}, "
f"RSS: {res['rss'] / 1024 / 1024:.2f} MB, "
f"USS: {res['uss'] / 1024 / 1024:.2f} MB, "
f"PSS: {res['pss'] / 1024 / 1024:.2f} MB, "
f"Shared: {res['shared'] / 1024 / 1024:.2f} MB, END"
)
data = self.data_npobject[idx]
data = np.array([data], dtype=np.int64)
return torch.tensor(data)
train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300, shuffle=True, drop_last=True, pin_memory=False, num_workers=18)
for i, item in enumerate(train_loader):
if i % 1000 == 0:
# print(i)
pass
이런 코드로 getitem에서 data 종류를 바꿔보면서 memory 증가여부를 살펴보면 눈으로 확인할 수 있다.
self.data는 pss가 증가
self.data_np는 안증가
self.data_dict도 증가
self.data_mp_dict는 안증가
하지만 object type의 nparray data_np_object는 증가한다. 유의.
다만 안증가하는 애들도 초기에는 증가하는 것을 유념해야 한다.
위 스크립트 기준으로 한 6000번은 뽑아내야 그 뒤로 일정해지는 것 같다.
덤으로 각종 oom kill log를 보는 팁이다.
이 둘은 oom kill 관련 log가 저장되는 것을 확인할 수 있다.
이건 감사로그를 활용하는 건데 oom kill이 아닌 sigkill을 전부 track할 수 있다.
대신 따로 설정이 필요하다.
auditd를 설치해준 뒤에 커맨드로 하나씩 저 룰을 추가해줘도 되고
/etc/audit/rules.d/audit.rules
이 파일에 위 내용들을 적은 다음
여기서 확인을 해보면 찍혀 있을 것이다.