본문 바로가기
  • 머킹이의 머신로그
AI/코드 실습하기

DDP 모델 학습에서 발생한 주요 에러

by 머킹 2024. 9. 10.
728x90

 

안녕하세요 머킹입니다.

모델 학습 과정에서 발생한 에러들을 보면서...

이걸 정리해두면 좋겠다 싶어서 글을 적습니다.

 


1. DataParallel vs DistributedDataParallel (DDP) 이슈

  • 문제: 처음에는 DataParallel을 사용해 여러 GPU에서 모델을 학습했지만, 메모리 오류 및 성능 저하 문제로 인해 안정적이지 않았습니다. 특히, cuda:0과 cuda:3에서 GPU 간 장치 혼합 오류가 발생했습니다.

DataParallel 구조
DDP 구조

해결 방법: DataParallel 대신 DistributedDataParallel (DDP)를 사용하여, GPU 자원 관리 문제와 메모리 문제를 해결했습니다. DDP는 각 GPU마다 별도의 프로세스를 할당해, 성능을 향상시킵니다. setup과 cleanup 함수를 통해 프로세스를 안전하게 설정하고 해제했습니다.

 

 
# DDP를 위한 초기화 및 해제 함수
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
  • train_ddp 함수는 각 프로세스에서 모델을 로드하고, DistributedDataParallel로 설정합니다.
 
def train_ddp(rank, world_size, model_name, train_dataset, eval_dataset, tokenizer, args):
    setup(rank, world_size)

    # 모델을 DDP로 설정
    device = torch.device(f'cuda:{rank}')
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
    model = DDP(model, device_ids=[rank])

 

 

2. CUDA 메모리 부족 오류

  • 문제: 모델 로드 중, AutoModelForCausalLM을 로드할 때 GPU 메모리가 부족하여 CUDA Out of Memory 오류가 발생했습니다.
  • 해결 방법:
    • torch_dtype=torch.bfloat16을 사용해 부동 소수점 16비트 연산을 적용하고, GradScaler를 사용해 혼합 정밀도 학습(AMP)을 활성화하여 메모리 사용량을 줄였습니다.
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    outputs = model(**batch)
    loss = outputs.loss / args['gradient_accumulation_steps']

 

3. 모델 저장 시 프로세스 간 충돌

  • 문제: DDP를 사용할 때 모든 프로세스에서 모델을 저장하려고 시도하여 충돌이 발생했습니다.
  • 해결 방법: Rank 0 프로세스에서만 모델 저장 작업을 수행하도록 조건문을 추가했습니다.
if rank == 0:
    model.module.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)

 

4. MultiProcessing Spawn 및 프로세스 종료 문제

  • 문제: torch.multiprocessing.spawn()을 사용할 때, 모든 프로세스가 제대로 종료되지 않아 경고 메시지가 발생했습니다.
  • 해결 방법: 모든 학습이 완료된 후 cleanup()을 호출해, 프로세스 그룹을 안전하게 종료했습니다.
def cleanup():
    dist.destroy_process_group()