Kevin Clark et al., ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators, iclr2020
- stanford와 구글 브레인에서 작성한 논문
pretraining에서 사용하는 masked language model(MLM)의 문제점
- 입력 텍스트의 일부만 masking하고, masking 된 토큰만 학습에 사용함 --> 비율적임.
- 본 논문에서는 입력 텍스트의 전체 토큰을 학습에 활용하는 방법을 제안함
- pretraining시의 입력 텍스트와 downstream task에서 사용되는 입력 텍스트 모양이 다름
- pretraining시에는 masked token이 포함된 텍스트를 입력으로 받음
- downstream task에서는 masked token은 없고, 일반 텍스트를 입력으로 받음
- 본 논문에서는 pretraining시에 masked token을 사용하지 않고, 일반 텍스트를 사용하는 방법을 제안함
replaced token detection
- 입력 텍스트에서 일부 토큰을 다른 토큰으로 변경하고, 각 토큰에 대해서 입력 텍스트의 토큰이었는지, 변경된 토큰이었는지를 판단하는 태스크(task)
- 입력 텍스트의 모든 토큰에 대해서 학습이 이루어지므로, MLM보다 효율적인 학습이 가능함
electra
- 본 논문에서는 generator와 discriminator의 2가지 네트워크를 사용함.
- 아래 그림에서 보듯이 discriminator 부분이 electra에 해당됨
- discriminator을 downstream task에서 사용함.
- generator는 discriminator를 학습시키는데 사용함
- generator는 일부 token이 masking되어 있는 masked text를 입력으로 받아서, masked token에 대한 replaced token를 생성함.
- replaced token이 원래 입력 토큰과 동일하면, 해당 토큰은 replaced token으로 간주하지 않음. 아래 그림에서 첫번째 단어인 'the'가 해당됨
- discriminator는 generator가 생성한 텍스트를 입력으로 받아서, 각 토큰이 원래 토큰인지 generator가 생성한 replaced token인지 판단함

electra 학습 방법
- generator와 discriminator는 transformer encoder로 되어 있다.
- generator와 discriminator는 서로 다른 네트워크 구조를 사용한다.
- generator의 layer 개수가 더 적다. discriminator보다 작은 generator를 사용하는 것이 좋음을 실험에서 확인했다. (그림3에서 왼쪽 그림 참조)
- generator로 transformer를 이용하지 않고, unigram model를 이용하는 경우도 실험했음. (그림3에서 왼쪽 그림 참조: transformer를 이용하는 것보다 결과가 좋지 않음)
- generator를 학습하면서, 동시에 discriminator도 학습을 한다. 이렇게 하는 것이 가장 좋은 결과를 보였다.
- generator를 학습 완료한 후에, discriminator를 학습하는 경우도 실험했음.
- generator를 adversarial하게 학습한 경우도 실험했음.
- 그림 3의 오른쪽 그림 참조
electra small model
- sequence length를 512에서 128로 줄임
- batch size를 256에서 128로 줄임
- hidden dimension size를 768에서 256으로 줄임
- token embedding을 768에서 128로 줄임
efficiency analysis
- electra와 아래 3가지 실험 결과를 비교함 (table 5 참조)
- electra 15%
- discriminator 학습시에 전체 토큰이 아닌 masked token만으로 학습
- discriminator 학습에서 전체 토큰을 사용하는 것이 품질 향상에 기여하는지 확인하기 위함
- 실험 결과:
- electra 보다 결과가 많이 낮음 (85 -> 82.4) . 즉, 전체 토큰을 학습에 사용하는 것이 품질 향상에 기여함
- replace MLM
- MLM 시에 입력에 mask 토큰 부분을 [MASK]라는 토큰 대신 generator가 생성한 토큰을 이용. 아마도 mask 토큰 부분의 원단어를 찾도록 학습할 듯.
- pretraining과 fine-tuning의 입력이 다른 문제를 해결. 이 방법으로 어느 정도의 품질 향상이 있는지 확인
- 실험 결과: BERT보다 살짝 좋음 (82.2 -> 82.4).
- pretraining과 fine-tuning의 입력 텍스트가 다른 문제([MASK] 토큰의 존재 유무)는 그리 큰 문제는 아님
- all-tokens MLM
- replace MLM과 동일한 형태로 입력을 만듬
- MLM에서 mask token에 대해서만 단어 복원을 하는데, 여기에서는 전체 단어를 복원하도록 학습함. 즉 mask 토큰은 원래 token을 찾도록 학습하고, 다른 토큰은 그 토큰 자체를 그대로 찾도록 학습
- 실험 결과: BERT보다 결과가 좋음 (82.2 -> 84.3). 즉, 모든 토큰을 학습에 함께 사용하는 것이 품질 향상에 기여함
- 실험 결과
- 입력 텍스트의 전체 토큰을 학습에 사용하는 것이 품질 향상에 기여
- pretraining과 fine-tuning의 입력 텍스트가 다른 문제는 그리 큰 문제는 아님
댓글