[2020]SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization
짧은 요약(Abstract) :
* Transfer Learning
** NLP에 큰 변화 가져옴
*** SOTA 달성
** 그러나 자원제한, PreTrain 모델의 복잡성은 문제 야기
*** Overfit(downstream data)
*** generalization 실패(unseen 대해)
** Principled 매너에 따른 새 학습 프레임워크 제안
*** 더 나은 성능(일반화)
* 두가지 함유(gradient에)
** 1. smootheness regularization
*** 복잡성 매니지 잘 함
** 2. Bregman Proximal Point Optimization
*** 신뢰구간 기법(지나친 업데이트 방지)
** SOTA 달성 NLP tasks(GLUE, SNLI, SciTail ,ANLI) 단어정리
- principled manner: 원칙에 입각한 방식
- harness: 이용하다, 활용하다
- inducing: 설득하다, 유도하다, 유발하다
- perturbation: 작은 변화, 섭동=행동을 다스림
- simplex: 단일어, 단체(정삼각형, 정사면체, 정오포체..)
- proximal: 근위의, 인접면
- symmetrized: 대칭된, 균형적
- Bregman divergence: 브레그먼 발산(구성요소의 명목 분포와 실제 분포 간의 편차를 설명할 때 사용)
- exponential moving average: 지수이동평균(가중변수를 이용하여 최근 수치의 영향력은 높이고 과거 수치의 영향력은 낮추는 것)
1 Introduction
- Transfer Learning 등장 배경
** 큰 레이블 데이터 구하기 어렵고 비쌈
** 특정 downstream data 아니어도 관련 task data로 pre-train
** 지식의 확장(전이)가 목표
** P-T -> F-T 순으로
** ELMo, GPT, BERT 유명
** semantic & synthactic info 모두 파악 가능
** unlabeled data로 학습
** 매우 큰 P-T 모델
** T5는 11B param(1000억개)
** F-T서 PLM to target task(domain)
** Top layer replace
** low resoruce 필요
** SOTA
** P-T 매우 복잡
** 급격히 F-T 경우 overfit
** 보완책
** 휴리스틱
** l.r 스케줄링 조절
** 점진 unfreezing
** adapt certain layers & freezing other
** add 추가 layers(모두 튜닝 노력 필요) - 해결 위해 강건, 효율 F-T 프레임웍 제안
** 2가지 overfit 방지책 포함
(1) 매우 큰 PP 컨트롤
** smootheness inducing adversarial 기술 제안
** local shift sensitivity에서 영감(robust 통계의)
** output 잘 안 바뀌게 함
** 스무스 모델 효과적 capacity control
(2) 급격 업데이트 방지위해 Bregman Proximal Point 최적화 사용
** 신뢰구간 제약(각 iter마다), 작은 이웃(이전 iter와)만 업데이트
** 급격 업데이트 방지 & F-T 안정화 - SOTA와 비교, 성능 향상 확인
** 본 모델 356M param 앙상블 없이 GLUE SOTA - 공헌
- smootheness inducing adversarial regularization & proximal point optimization into LLM F-T 기법 소개
- SOTA 달성(GLUE, SNLI, SciTail, ANLI)
- 표기
** f(xi ceta): 매핑
** f가 para cete input x 연관시킴
** output: 다차원 확률, simplex(단체) for 분류/회귀
** Pkl(P||Q) = Sigma k pk log (pk/qk)는 KL 다이버전스 p분포, q분포 param qk, pk
2 Background
- T-F L 배경
** NMT 첫 제안
** BERT 제안 양방향 트랜스포머 기반 인간 annotate 없이 큰 성과
** 큰 data, 큰 모델 영감되서 많이 나옴
** PLM 파인튠 채택
** 탑 layer는 특정 task용
** overfit 방지, 휴리스틱 사용
** 작은 l.r or triangular l.r 스케줄링, 적은 iter
** regularization 기술 제안
** 이미 있는 기술과 유사, but 다른 응용, 다른 동기부여 e.g. semi 지도학습, 비지도학습 도메인 적응, 제악 적대 in image 분류 - 본 최적화 기술
** 큰 class의 Bregman proximal point 기법 커버(록펠러 vanila proximal point, accelerated proximal point 방법 등) - 관련 F-T 방법
** FreeLB: robust adversarial training 방법
** 본 논문은 local smoothness에 초점 -> 성능 올림
3 The Proposed Method
- SMART 프레임웍 제안
** SMoothness-inducing Adversarial Regularization and BRegman pRoximal poinT opTimization
3.1 Smoothness-Inducing Adversarial Regularization
- model f(.; ceta)와 n개 data points of target task {(xi, yi)} i=1 to n
** xi는 input sent의 임베딩
** yi는 레이블
** 다음식으로 최적화 min ceta F(ceta) = L(ceta) + lamdas Rs(ceta), L(ceta) = 1/n(Sigma i=1 to n l(f(xi;ceta), yi)
** L is loss function
** l(.,.)도 loss 함수 target 의존적
** lamda s > 0이 튜닝 param
** Rs(ceta)가 smoothess inducint adversarial regularization
** ||xi_tilda - xil||p <= epcilon
** Rs(ceta) = 1/n(Sigma i=0 to n max lr(f(xi_tilda;ceta), f(xi;ceta)),
*** epsilon은 튜닝 param
** f(.;ceta)는 분류 task이고
** output 은 확률 simplex
** ls는 KL divergence로부터 선택
** 회귀서 f(.;ceta) 결과는 스칼라
** ls는 squared loss로 선정
** Rs(ceta) 계산은 최대화 문제 gradient ascent - 역사
** Miyato가 준지도학습서 사용
** Shu가 비지도 도메인 적용에 사용
** Zhang이 이미지 분류서 적대 예 제한에 사용
** 파인튠에 사용은 본 모델이 처음 - Smoothness-inducing adversarial regulariation
** local Lipschitz 측정 f under metric ls(perturbation(섭동, 유도, 작은 행동?) 작으면 f 별로 안 변함)
( (1) 줄이면 f가 smooth해짐, xi 이웃에 대해)
** overfit 방지
** generalization 향상(적은 리소스로) - 참고
** local Lipschitz continuity는 local shift sensitivity criterion과 비슷(1960’s)
** 이 criterion이 가늠자들의 독립성 측정 때 사용
3.2 Bregman Proximal Point Optimization
- (1) 풀수 있고 Bregman Proximal point 최적화 기법 클래스를 제안
** 급진 update 억제
** P-T 모델을 초기화로 사용 f(.;ceta)로 표기, (t+1)th 반복서
** VBPP(Vanila Bregman Proximal Point)는 ceta t+1 = argmin ceta F(ceta) + mu DBreg(.,.), mu는 튜닝 param이고 >0일때 사용
** Bregman divergence DBreg는 DBreg(ceta, cetat) = 1/n Sigma i=1 to n ls(f(xi;ceta), f(xi;cedtat))로 정의
** ls는 section 3.1에서 정의, mu 커지면 VBPP의 DBreg가 각 iter 마다 강하게 제약됨
** ceta t가 너무 커지는 것 방지
** 즉, Bregman proximal point method는 효과적으로 보유(도메인 벗어난 지식, P-T의 f(.,ceta))
** 각 sub 문제
** (2)(VBPP 중) 는 closed form solution으로 인정 x SGD류로 풀어야(ADAM 같은) - 비고
** 각 subprob converge까지 불필요
** 작은 iter면 충분
Acceleration by Momentum
- 모멘텀 가속화
** 모멘텀 써서 가속화
** t+1 번째 iter, 모멘텀 MBPP
** ceta t+1 = argmin ceta F(ceta) + mu DBreg(ceta, cetat_tilda), cdetat_tilda = (1-beta)ceta t + beta ceta t01_tilda, beta (= (0,1) 인 exponential 이동평균으로 모멘텀 param임
** Mean Teacher 라고도 불림
** 준지도학습서 SOTA 보임, 편의 위해 MBPP는 알고리즘1로 칭함
4 Experiment - Main Results
- SMART F-T GLUE 사용, SOTA와 비교
4.1 Implementation Details
- 구현 상세
** BERT, RoBERTa, MT-DNN, HNN 기반
** ADAM, RADAM 사용 optimizer
** l.r (= {110^-5, 210^-5, 310^-5, 510^-5}
** batch size (= {16, 32, 64}
** max epoch 6
** l.r decay schedule with mwarm up 0.1
** dropout rate all layer as 0.1 except 0.3 for MNLI, 0.05 for CoLA
** gradient exploading 피하려고 gradient norm 1로 고정
** word piece로 토크나이징
** 최대길이 512
** perturbation(작은 변화?) size epcilon=10^-5, sigma=10^-5, mu=1, lamdas(={1,3,5}
*** l.r eta in algorithm 1 is set to 10^-3, beta=0.999
4.2 GLUE Main Results
- GLUE test
** BERT base 기반 & RoBERTa 라지 기반
** 버트베이스->베이스라인
** RoBERTa 사용
** 위에 RoBERTa PGD, FreeAT, FreeLB adversarial training 기법 built
** SMART 본 제안 모델
** 공정비교위해 버트 재구축->기존 버트보다 성능 오름
** 성능 기존 능가 in 8GLUE
** Adversarial training으로 GLOVE서 SOTA 능가
** T511billion param 대비 SMART는 356billion으로 더 효율적이며 SOTA
5 Experiment - Analysis and Extension
- 실험-분석과 확장
** SMART가 MT에 보완되는지 검증
** SNLI, SciTail의 도메인 적응되나 평가
** ANLI에서 강건성 test
5.1 Ablation Study
- 경감 study
** SMART 사용이 성능 최고
5.2 Error Analysis
- 에러분석
** MNLI 모호 샘플도 분석
** 샘플 4개로 나눔
1) 5/0/0 - 5개 레이블 같음
2) 4/1/0 - 4개 레이블 같음
3) 3/2/0 - 3개 같고 다른 2개 같음
4) 3/1/1 - 3개 같음
** SMART RoBERTa가 RoBERTa 압도
5.3 SMART with Multi-task Learning
- MT 사용 SMART
** MTDNN SMART가 성능 압도, 상호 보완적 입증
5. 4 Domain Adaptation
- 도메인 적용
** SNLI & SciTail set 이용
** 랜덤 샘플 0.1, 1, 10, 100%
** MT-DNN & MT-DNN-SMART
** 압도 SOTA
5.5 Results on SNLI and SciTail
** MT-DNN-SMART가 압도 SOTA
5.6 Robustness
** SMART가 압도, combine set test서
6 Conclusion
- 결론
** 강건 효율적 SMART 제안 (파인튜닝 framework)
** overfit 경감, 급격 update 방지
** SMART 구성
1) smooth-inducing adversarial regularization
2) Bregman proximal point optimization
** SOTA in NLP benchmarks(GLUE, SNLI, Scitail, ANLI)
** 도메인 적응 서로 향상
** 다른 곳에서도 잘 적용 예상