[논문리뷰] (InstructGPT) Training language models to follow instructions with human feedback 논문 리뷰 요약

2025. 2. 6. 16:59paper

 

Improving Language Understanding by Generative Pre-Training
InstructGPT - [2203.02155] Training language models to follow instructions with human feedback]
Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll L. Wainwright, Pamela Mishkin, OpenAI
NeurIPS, 2022

 

- https://arxiv.org/abs/2203.02155

 

Contents

1. ChatGPT

    1-1. GPT-1 

    1-2. GPT-2 

    1-3. GPT-3

    1-4. GPT-4

    1-5. Instruct GPT - (Training language models to follow instructions with human feedback)

    1-6. RLHF

2. LLama 

3. Gemini

4. DeepSeek

5. LLaVA

    5-1. LLaVA - https://jaeha-lee.tistory.com/93

    5-2. LLaVA-Next
    5-3. LLaVA-NeXT-Interleave

   

 

 

 

Summary

 

- 기존 LLM의 경우 untruthful, toxic, not helpful한 output을 자주 생성
     - 단순히 다음 텍스트를 유추하는 방식으로 진행하여 위와 같은 문제들이 계속 발생함

 

- 위 문제를 해결하기 위해서 아래 시도를 하고 결과에 대한 요약은 다음과 같음

     - human feedback을 이용한 데이터 + 강화 학습 기법을 활용하여 (PPO & PPO-ptx) GPT-3를 fine-tuning

     - 훨씬 적은 매개변수를 가지고도 기존 GPT-3보다 human intent 잘 맞춤 : InstructGPT 1.3B >> GPT-3 175B 

     - GPT-3에 비해 toxicity 에 개선 / bias는 개선하지 못함

- 아키텍처는 GPT-3 그대로 사용

 

*toxic : 유해한 정보

 

 

 

Method

- 인간 피드백을 통한 강화 학습

- instruction을 따르도록 language model을 학습함

- 학습과정 - 예시

     - step1 : (demonstration data(사람이 만든 데이터) -> 학습 생성 방법) = SFT 모델 생성
         1) 6살에게 달 착륙에 대해 설명해줘봐

         2) 사람이 설명을함

         3) 2) 데이터를 가지고 GPT-3를 fine-tuning 함  

 

     - step2 - (비교 데이터를 모으고 RM 모델을 학습함) - RM 모델 학습

         1) 하나의 prompt와 똑같은model의 여러 output을 가지고 옴
             - 하나의 prompt : "6살에게 달 착륙에 대해 설명해줘봐"
             - 여러 개의 model output : 모델에서 위 prompt에 대한 여러 답변들

         2) 사람이 best -> worst 까지 순위를 매김

         3) 해당 데이터로 RM 모델을 학습함

 

     - step3

         1) 데이터 셋으로부터 새로운 prompt를 가져옴

         2) 모델이 output을 냄

         3) RM이 output에 대해 reward를 계산함

         4) reward를 사용하여 PPO를 통해 policy를 update 함

 

학습 과정

     step 1 : GPT-3 모델을 finetuning 시켜서 Superviesed fine-tuning(SFT) 모델 생성

     setp 2 : SFT 모델을 이용해 RM을 만들어 한번 더 학습 (사람의 선호도가 높을 수록 높은 reward)

     즉 GPT-3 에다가 SFT와 강화학습을 결합하여 fine-tuning한 방법이 instructGPT


*policy - 강화학습에 사용하는 용어로, 주어진 상태에서 어떤 행동을 취할지를 알려주는 일종의 지침서

*regression performance - fine-tuning 후 특정 작업에 대한 성능이 저하되는 현상

*PPO(Proximal Policy Optimization) - 강화 학습 알고리즘의 일종, policy gradient 방법을 기반으로 하며 agent가 환경과 상호 작용하여 보상을 극대화 하는 최적의 policy를 학습하는 것

 

 

Dataset 

 

학습 데이터는크게 2종류 : "OpenAI API를 활용한 prompt | 답변"  + "라벨러들이 직접 생성한 prompt | 답변"

 

라벨러들에게 아래와 같은 기준으로 데이터를 생성하게 함


     - Plain - 자유롭게

     - Few-shot - instruction 형식으로 작성 

     - User-based - OpenAI API 에서 작성한 prompt를 읽고

 

 

- Table 1 - Generation, Open QA, Brainstorming ... Table 2 - Table 1 예시

 

- 이용자당 prompt 200개 제한

- 각각의 prompt에 ID(user 구분) 하여 valid, testset에서 trainset이 안들어가게 함

- 처음에는 prmopt를 라벨러들이 직접 쓰게 해서 초기 source를 구성함. 이 데이터는 GPT-3에 사용되진 않음

 

 

Models

1) Superviesed fine-tuning(SFT)
- cosine learing rate decay와 residual dropout 0.2 사용 / 16 에폭  / 1 에폭만에 validation에 대해 overfiiting 되었지만 계속 학습할수록 RM score와 human score 모두 도움이 됨

2) Reward modeling(RM)
- SFT에서 unembedding layer를 제거함 | 원래 unembedding layer는 임베딩 된 단어를 원래 단어로 바꿈 | 얘를 제거하면 스칼라 값이 나옴
- prompt와 response를 입력으로 받아서 스칼라 값을 출력 함

- 연산량 줄이기 위해 6B RM만을 사용. 오히려 175B를 사용하면 훈련이 잘 안됨

 

 

- InstructGPT에서는 SFT를 통해 같은 prompt에 대해 4~9개에 해당하는 답변을 만들고 / 생성한 데이터를 라벨러가 순서를 매김

- 그럼 이 각각을 2개 씩 뽑아 비교를 진행함 (combination)


- loss function : | E : 기대값 | x : prompt | y : 출력 | yw > yl (선호도) | r : reward | σ : sigmoid |

- 결국 사람 선호도가 높은 출력에 대해 reward를 주는 방향으로 학습을 함

 

3) RL

- PPO(Proximal Policy Optimization) 방법의 강화학습을 사용

- ϕ라는 파라미터를 objective function

- E : 기대값 | x: prmopt, y: 출력 | r(x,y) step2에서 사용 세타로 여기서 업데이트 되지 않음 | 
- β : 하이퍼파라미터 | π라는 강화학습 모델 파라미터는 ϕ, x:prompt를 넣었을 때 y가 나올 확률 |  SFT 모델에 x:prompt 넣었을 때 y가 나올 확률)

- InstructGPT는 PPO-ptx 라는 것을 만들었음

- 뒤에 항을 추가 : 학습전 pretraining 데이터 분포를 알 수 있게 하기 위한 장치.. (이해는 안감)

- PPO를 하게 되면 performance regression이 일어나게 되는데, 이를 막아주는 역할을 함. 실제로 이걸 쓴 모델이 성능이 더 좋음 (근데 파라미터가 더 많을 때 좋아지는데 효과적인게 맞나...)

 

 

 

Experiments

3부분으로 나눠서 실험 검증을 함

 

1. API prompt distribution

- InstructGPT를 GPT-3보다 확실히 더 선호함

 

 

 

2. public NLP dataset

3.  qualitative result

 

 

뭐 성능 좋다는 얘기고, Hallucination의 경우 SFT가 제일 낮게 나옴

Conclusion

 

 

Ablation Study