상세 컨텐츠

본문 제목

[Advanced ML & DL Week4] BERT : Pre-training of Deep Bidirectional Transformers for Language Understanding

심화 스터디/Advanced ML & DL paper review

by needmorecaffeine 2022. 11. 3. 09:58

본문

작성자 : 14기 김태영

논문 링크 : BERT, Pre-training of Deep Bidirectional Transformers for Language Understanding

1. Introduction

 BERT가 등장하기 이전의 SOTA 모델로는 GPT-1이 있었고 이는 많은 양의 데이터로 학습된 pre-trained된 모델이었다. 하지만 GPT-1을 포함한 기존 모델들은 문장을 학습할 때 순차적으로 읽어가며 학습해야 한다는 단점이 있었다. 단어 임베딩에서 Transformer를 사용하여 Attention을 통해 토큰 간 관계를 잘 파악하도록 작업할 수 있었지만 결국 예측할 때는 문장을 왼쪽부터 오른쪽으로 읽으며 예측해 이전 토큰만 참조할 수 있다는 단점은 여전히 존재했다.

 

 이러한 단점은 다음 문장 예측, 문장 빈칸 예측의 태스크에서 성능을 저하시키는 요인이 되었다. 그렇다면 양방향으로 학습 문장을 읽으면 성능이 더 개선되지 않을까? 

 

 이것을 가능케 한 것이 이 논문의 BERT MLM 방식이고 이를 이용해 Bidirectional Transformer 이용이 가능해졌다.

BERT에 대한 자세한 설명 이전에 BERT와 같은 사전 학습 언어 모델을 적용시키는 방법에 대해 살펴보면 크게 다음과 같은 두 가지 방식이 있다.

 

  • Feature-based approach : Elmo와 같이 사전학습된 모델을 추가적인 feature로 포함하는 task-specific 구조. 극히 일부를 제외하고 임베딩까지 포함한 모든 것을 업데이트하는 구조
  • Fine Tuning : GPT와 같이 최소한의 task-specific parameter를 알려주고 pretrain된 파라미터들을 downstream task 학습을 통해 조금만 바꿔 적용하는 구조. 임베딩은 그대로 두고 그 윗단의 레이어만 학습하는 구조

 이러한 기존 논의된 사전학습 구조들은 전에 언급한 것과 마찬가지로 그 학습 방향이 unidirectional 하고 왼쪽에서 오른쪽으로 향하며 모든 토큰이 앞선 토큰의 어텐션만 가능하다는 단점이 있다.


2. BERT Architecture

 위와 같이 Pre-training, Fine-tuning 두가지 단계로 구성되어 있다.

 

 pre-training 동안에는 모델은 여러 unlabeled data를 학습한다. fine tuning을 위해 모델은 pre-training parameter를 초기화하고 대량의 파라미터들은 down stream task로 labeled data를 통해 fine tuning된다. 각각의 down streeam 작업은 분리된 fine tuning 모델을 가지고 있다.

 

  BERT가 높은 성능을 얻을 수 있었던 이유는 위와 같이 레이블이 없는 방대한 데이터로 사전 훈련된 모델을 만든 후 레이블이 있는 다른 작업에서 추가훈련을 진행하여 하이퍼파라미터를 재조정하기 때문이다. 이렇게 pre-train된 모델과 fine-tune된 모델을 비교해보면 BERT는 구조적 차이가 거의 없고 미세한 차이만을 가진 여러 개의 모델을 가지게 된다.

 

  BERT는 트랜스포머의 인코더만을 다중으로 쌓은 다중 레이어 양방향 트랜스포머 인코더라고 할 수 있다. 

말 그대로 기존의 트랜스포머와 다르게 앞과 뒤 양방향에서 어텐션을 사용한다. 

 

 논문에서 설명하는 두 모델의 사이즈는 다음과 같다. 

 

L = 레이어 갯수(트랜스포머 블럭 개수) / H = hidden state / A = self-attention head 갯수

  • BERT BASE : L = 12 / H = 768 / A = 12 >> total parameters = 110M

  • BERT LARGE : L = 24 / H = 1024 / A = 16 >> total parameters = 340M

3. Input / Output Representation

 본 논문에서 정의하는 문장과 시퀀스의 정의는 다음과 같다.

  • 문장(sentence) = 텍스트의 임의 범위 (실제 언어학적 문장일 필요 없음)

  • 시퀀스(sequence) = BERT의 입력 토큰 시퀀스로 두 개의 문장을 함께 패킹한 것

 시퀀스의 정의와 같이 문장의 쌍은 하나의 시퀀스로 묶이게 되고 다음과 같은 두가지 방법으로 그 문장 쌍을 구분한다.

  • 토큰 [SEP]를 통해 분리

  • 이것이 어떤 문장인 표시하는 학습된 임베딩 추가

 BERT의 tokenizer는 WordPiece 임베딩 방식이고 문장의 첫번째 토큰은 항상 [CLS]로 이 토큰을 통해 마지막 분류 시 분류에 대한 값이 [CLS] 토큰의 연산 결과로 나타난다.

 

 위 사진을 보면 각각의 임베딩 레이어를 통해 나타나는 임베딩을 E로 표시하고 각 임베딩 층은 다음과 같은 기능을 한다.

 

  • Token Embedding = 실질적인 입력이 되는 워드 임베딩. 임베딩 벡터의 종류는 단어집합의 크기로 30,522개

  • Posiotin Embedding = 위치 정보를 학습하기 위한 임베딩. 임베딩 벡터의 종류는 문장의 최대 길이인 512개

  • Segment Embedding = 두 개의 문장을 구분하기 위한 임베딩. 임베딩 벡터의 종류는 문장의 최대 갯수인 2개

 이러한 세 개의 임베딩 레이어를 지나 BERT모델에 들어가고 연산이 된다.


4. Pre-training

 BERT는 기존의 방식들과 다르게  bidirectional하게 학습을 하였으며 내부적으로 두 가지 비지도 학습 방식을 통해 훈련되었다.

4-1. Masked LM

 아무 제약 없는 양방향 학습을 진행할 경우, 간접적으로 예측하려는 단어를 참조하게 되고 이는 multi layer 구조에서 해당 단어를 자기자신 참조로 예측하게 될 수 있기 때문에 제대로 된 학습이 되지 않는다.

 

 이를 위해 등장한 방식이 Masked LM, MLM이고 MLM은 다음 단어가 무엇이 오는지 예측하는 것이 아니라 문장 내에서 무작위로 입력 토큰의 일정 비율을 마스킹하고 마스킹된 토큰을 예측한다. 마스크 토큰에 해당하는 마지막 hidden  vector는 토큰을 통해 출력 소프트맥스로 주어지고 단어를 예측하게 된다.

 BERT의 경우 wordpiece 토큰의 15%를 무작위로 각 시퀀스에서 마스킹하고 마스킹된 단어의 예측이 모델의 목표가 된다.

이렇게 pre train 될 경우 fine tuning 시에는 mask 토큰이 없어 pre train과 fine tune 결과가 불일치하는 단점도 존재한다.

 

 이런 단점을 보완하기 위해 mask하기로 지정된 데이터를 훈련 데이터를 생성 시, 예측을 위해 무작위로 토큰 포지션의 15%를 선택한다. 만약 i번째 토큰이 선택된다면, i번째 토큰 중 80%는 MASK 토큰으로, 10%는 다른 토큰으로 교체, 10%는 변경되지 않은 i번째 토큰을 사용한다.

 

4-2. Next Sentence Prediction(NSP)

 MLM 학습과 동시에 BERT는 QA로도 활용되기 위해 두 개의 문장을 제공하고 이 두문장이 이어지는 문장인지 분류하는 훈련도 거치게 된다. 이때 문장은 50% 확률로 이어지는 문장이다.

 

 이 때 [SEP] 토큰을 기준으로 각 문장이 다른 문장임을 구분하고 가장 앞에 있는 [CLS]  토큰에 이 문장이 이어지는 문장인지 예측하게 된다.


5. Fine-tuning

 pre-train된 BERT로 해결하고자 하는 task의 데이터를 추가적으로 학습시켜 검증하게 된다. fine tuning 하는 방식은 다음과 같은 네 가지 학습 방식이 있다.

 

  • Paraphrasing
  • Hypothesis - Premise pairs in entailment : 두 문장이 이론과 가설 관계를 가지는지 분류
  • Question - Answering
  • Tagging / Text pair classification : 각 단어에 품사 태깅 / 문장 간의 관계 예측

6. Experiments


7. Reference

관련글 더보기

댓글 영역