본문 바로가기
DL|ML

Self Supervised Learning를 여행하는 히치하이커를 위한 안내서 (2) - DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter

by 이든Eden 2022. 6. 15.

논문, 코드

[그림 1] 은하수를 여행하는이 아니라 Self Supervised Learning를 여행하는 히치하이커를 위한 안내서

 

Self Supervised Learning를 여행하는 히치하이커를 위한 안내서

저는 요즘 모두의 연구소에서 Self Supervised Learning(SSL)을 공부하고 SSL에 대해 논문을 쓰는 것을 목표로 하는 SSL LAB(쓸랩)의 연구원으로도 활동하고 있습니다. 벌써 2개의 논문을 끝냈습니다. 공부하고 있는 논문들이 SSL을 공부하는 좋은 로드맵이 될 것 같아 살짝 정리를 해보는 것이 좋을 것이라 생각했습니다. 논문 설명과 논문에서 공헌한 부분의 코드를 설명하는 시리즈가 될 것 같습니다.

 

 

미리 알아두면 좋은 지식

- Distilling the Knowledge in a Neural Network

: 지난 번 SSL 히치하이거 시리즈에서 소개한 논문입니다. Knowledge Distillation을 소개하는 논문이고 DistilBERT 역시 큰 뼈대는 같은 방식을 사용하고 있기 때문에 한 번 읽어보시는 것이 좋습니다.

- BERT

: 2018년 구글에서 소개한 NLP 분야의 pre-trained 모델입니다. Transformer의 인코더를 쌓아올린 아키텍처를 가지고 있습니다. BERT-base는 110M개의 파라미터, BERT-Large는 무려 340M개의 파라미터를 가지고 있습니다. 해당 논문은 BERT의 모델을 변형한 DistilBERT를 Student로 사용하여 학습하는 방법을 이야기하기 때문에 BERT의 아키텍처(Embedding, Transformer, pooler 등)에 대해 아는 것이 좋습니다.

 

DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter

[그림 2]에서 볼 수 있듯 거대한 사이즈의 pre-trained 모델들이 공개되고 있고, 그것을 이용해 transfer learning을 하는 것은 NLP 분야에서 일반적인 과정이 되었습니다.  

[그림 2] 거대한 파라미터 사이즈를 가진 language 모델들

 

이러한 모델에는 몇 가지 문제가 있습니다.

  • 계산(computational) requirements가 많이 필요하다는 것
  • On-device로 모델을 이용하게 된다면 새롭고 흥미로운 언어 처리 응용 프로그램을 가능하게 할 가능성이 있지만 계산 및 메모리 요구 사항이 모델 채택을 방해한다는 것

그래서 DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter 논문에서는 Knowledge Distillation 방법을 적용해 기존의 BERT의 97% language understanding capacity를 유지하면서 모델 사이즈의 40%를 줄이고 60%는 더 빠른 모델을 만드는 법을 제안했습니다.

 

 

Architecture

우선 [그림 3]을 보며 지난번에 이야기했던 Distilling the Knowledge in a Neural network의 아키텍처와 학습과정을 먼저 떠올려봅시다. 우선 pre-trained 모델인 Teacher와 Knowledge Distillation으로 학습할 Student가 있을 것입니다. Teacher를 freeze 시킨 상태로 Distillation Loss, Student Loss로 학습을 시킵니다. 그리고 Softmax를 계산할 때 사용하는 하이퍼파라미터 temperature T도 잊지말아야하죠.

 

[그림 3] knowledge Distillation Architecture

DistilBERT도 이러한 과정을 이용해 학습되었습니다. Teacher, Student, Loss function이 어떻게 되어있는지 비교해봅시다. Teacher는 기존의 BERT를 이용했습니다. 그리고 Student는 어떨까요? 분명 Teacher 모델인 BERT보다 더 작은 사이즈인 모델일 것입니다.

 

Student architecture 논문에서는 이 Student 모델인 DistilBERT를 기존의 BERT 모델에서 token type embedding과 pooler 레이어를 제거했습니다. BERT는 Transformer의 인코더를 쌓아올린 구조를 가지고 있다고 했습니다. Transformer를 주로 이루고 있는 linear layer, layer normalisation 등의 레이어는 modern linear algebra frameworks 들로 이미 최적화되어 있기 때문에 layer 갯수를 줄이는 것이 dimension 을 computation efficiency에 큰 영향을 준다는 이유로 두 레이어가 제거되었습니다.

 

Student Initialization 논문의 저자들은 student architecture 선택외에도 학습이 잘되기 위해서 weight initialization이 잘 되어있어야 한다고 합니다. 따라서 Teacher 모델과 Student 모델의 dimension이 같다는 장점을 이용해 Teacher의 weight를 Student에 사용했습니다.

 

Training loss DistilBERT는 총 3개의 loss를 linear한 결합으로 사용하고 있습니다. $L_{CE}$, $L_{mlm}$, $L_{cos}$ 가 존재합니다. $L_{CE}$는 teacher의 softmax-temperature 값인 $t_i$, student의 softmax-temperature 값인 $s_i$으로 최종 $L_{CE} = \sum_{i}t_i ∗ log(s_{i})$를 계산합니다. Softmax-temperature는 Distilling the Knowledge in a Neural network 방식을 그대로 사용했습니다. $L_{mlm}$는 Student는 BERT 기반 모델이기 때문에 BERT의 masked language modeling loss를 의미합니다. 기존의 Knowledge Distillation 모델에서는 Distillation을 해주는 Loss, Student를 학습하는 Loss로 총 2가지 Loss만을 사용했는데 DistilBERT에서는 $L_{cos}$를 추가하여 총 3개의 loss를 가지고 있습니다. 나머지 한 개의 Loss의 역할은 Teacher와 Student 모델의 마지막 hidden state의 vector 방향을 align하게 해주는 효과가 있어서 추가했다고 합니다.

 

  • Teacher : 기존 BERT
  • Student : DistilBERT (기존 BERT에서 Token-type embedding 과 Pooler 제거)
  • Loss : $L_{CE}$, $L_{mlm}$, $L_{cos}$

 

DistilBERT 결과 [그림 4] DistilBERT의 결과입니다. Table 1을 보면 기존 BERT랑 비교해서 모델의 97%의 퍼포먼스를 유지한다는 것을 알 수 있고 Table 2를 통해 Downstream Task의 퍼포먼스도 크게 차이나지 않는다는 것을 알 수 있습니다. 마지막으로 Table 3을 통해서 ELMo와 BERT-base보다 훨씬 더 적은 파라미터를 가지고 inference time이 더 빠르다는 것을 알 수 있습니다.

 

[그림 4] DistilBERT 결과

 

DistilBERT Code

 

DistilBERT Loss, BERT와의 다른 점을 코드를 보며 이해해봅시다. 이곳의 코드를 참고했습니다. 지난번에 Knowledge Distillation을 하기위해 Teacher, Student, Distiller(이곳에 두 모델을 보내서 학습)을 만들었습니다. 아래의 코드는 train.py의 일부입니다.

 

## STUDENT ##
logger.info(f'Loading student config from {args.student_config}')
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
stu_architecture_config.output_hidden_states = True
if args.student_pretrained_weights is not None:
logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
student = student_model_class.from_pretrained(args.student_pretrained_weights,
config=stu_architecture_config)
else:
student = student_model_class(stu_architecture_config)
if args.n_gpu > 0:
student.to(f'cuda:{args.local_rank}')
logger.info(f'Student loaded.')
## TEACHER ##
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f'cuda:{args.local_rank}')
logger.info(f'Teacher loaded from {args.teacher_name}.')
## FREEZING ##
if args.freeze_pos_embs:
freeze_pos_embeddings(student, args)
if args.freeze_token_type_embds:
freeze_token_type_embeddings(student, args)
## SANITY CHECKS ##
assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
if args.mlm:
assert token_probs.size(0) == stu_architecture_config.vocab_size
## DISTILLER ##
torch.cuda.empty_cache()
distiller = Distiller(params=args,
dataset=train_lm_seq_dataset,
token_probs=token_probs,
student=student,
teacher=teacher)
distiller.train()
logger.info("Let's go get some drinks.")
view raw train.py hosted with ❤ by GitHub

[코드 line 1 -15]

Student 모델을 만들고 Student Initialization을 하는 것을 볼 수 있습니다. 그리고 마지막 줄 15에서 device에 모델을 로드합니다.

 

[코드 line 17 -21]

Pre-trained된 Teacher 모델을 만들고 device에 모델을 로드합니다.

 

[코드 line 23 - 27]

freeze_pos_embeddings는 Student 모델이 ['roberta', 'gpt2'] 일 때, freeze_token_type_embeddings는 Student 모델이 ['roberta'] 일 때만 해당되므로 넘어가겠습니다.

 

[코드 line 29 - 34]

Teacher 모델과 Student 모델의 configuration이 동일한지 체크합니다.

 

[코드 line 36 - 43]

Distillation 과정을 가지고 있는 Distiller 클래스에 Teacher 모델, Student 모델, 데이터 등을 넣고 학습을 시작합니다.

 

 

이제 Distiller 클래스에서 어떻게 DistilBERT가 학습되는지 확인해봅시다. 코드가 꽤 길지만 겁먹지 말고 필요한 부분을 찾아봅시다.

# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import math
import os
import time
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
class Distiller:
def __init__(
self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
):
logger.info("Initializing Distiller")
self.params = params
self.dump_path = params.dump_path
self.multi_gpu = params.multi_gpu
self.fp16 = params.fp16
self.student = student
self.teacher = teacher
self.student_config = student.config
self.vocab_size = student.config.vocab_size
if params.n_gpu <= 1:
sampler = RandomSampler(dataset)
else:
sampler = DistributedSampler(dataset)
if params.group_by_size:
groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
else:
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
self.temperature = params.temperature
assert self.temperature > 0.0
self.alpha_ce = params.alpha_ce
self.alpha_mlm = params.alpha_mlm
self.alpha_clm = params.alpha_clm
self.alpha_mse = params.alpha_mse
self.alpha_cos = params.alpha_cos
self.mlm = params.mlm
if self.mlm:
logger.info(f"Using MLM loss for LM step.")
self.mlm_mask_prop = params.mlm_mask_prop
assert 0.0 <= self.mlm_mask_prop <= 1.0
assert params.word_mask + params.word_keep + params.word_rand == 1.0
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
if self.fp16:
self.pred_probs = self.pred_probs.half()
self.token_probs = self.token_probs.half()
else:
logger.info(f"Using CLM loss for LM step.")
self.epoch = 0
self.n_iter = 0
self.n_total_iter = 0
self.n_sequences_epoch = 0
self.total_loss_epoch = 0
self.last_loss = 0
self.last_loss_ce = 0
self.last_loss_mlm = 0
self.last_loss_clm = 0
if self.alpha_mse > 0.0:
self.last_loss_mse = 0
if self.alpha_cos > 0.0:
self.last_loss_cos = 0
self.last_log = 0
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
if self.alpha_mse > 0.0:
self.mse_loss_fct = nn.MSELoss(reduction="sum")
if self.alpha_cos > 0.0:
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
logger.info("--- Initializing model optimizer")
assert params.gradient_accumulation_steps >= 1
self.num_steps_epoch = len(self.dataloader)
num_train_optimization_steps = (
int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": params.weight_decay,
},
{
"params": [
p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": 0.0,
},
]
logger.info(
"------ Number of trainable parameters (student): %i"
% sum([p.numel() for p in self.student.parameters() if p.requires_grad])
)
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
self.optimizer = AdamW(
optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
)
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
)
if self.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
self.student, self.optimizer = amp.initialize(
self.student, self.optimizer, opt_level=self.params.fp16_opt_level
)
self.teacher = self.teacher.half()
if self.multi_gpu:
if self.fp16:
from apex.parallel import DistributedDataParallel
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(self.student)
else:
from torch.nn.parallel import DistributedDataParallel
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(
self.student,
device_ids=[params.local_rank],
output_device=params.local_rank,
find_unused_parameters=True,
)
self.is_master = params.is_master
if self.is_master:
logger.info("--- Initializing Tensorboard")
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
def prepare_batch_mlm(self, batch):
"""
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM.
Input:
------
batch: `Tuple`
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
Output:
-------
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -100 where there is nothing to predict.
"""
token_ids, lengths = batch
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
assert token_ids.size(0) == lengths.size(0)
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
bs, max_seq_len = token_ids.size()
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
x_prob = self.token_probs[token_ids.flatten()]
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
pred_mask = torch.zeros(
bs * max_seq_len, dtype=torch.bool, device=token_ids.device
) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len)
pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
# mask a number of words == 0 [8] (faster with fp16)
if self.fp16:
n1 = pred_mask.sum().item()
if n1 > 8:
pred_mask = pred_mask.view(-1)
n2 = max(n1 % 8, 8 * (n1 // 8))
if n2 != n1:
pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
pred_mask = pred_mask.view(bs, max_seq_len)
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
_token_ids_real = token_ids[pred_mask]
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
_token_ids = (
_token_ids_mask * (probs == 0).long()
+ _token_ids_real * (probs == 1).long()
+ _token_ids_rand * (probs == 2).long()
)
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
# sanity checks
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
return token_ids, attn_mask, mlm_labels
def prepare_batch_clm(self, batch):
"""
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the labels for CLM.
Input:
------
batch: `Tuple`
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
Output:
-------
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -100 where there is nothing to predict.
"""
token_ids, lengths = batch
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
assert token_ids.size(0) == lengths.size(0)
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
# sanity checks
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
return token_ids, attn_mask, clm_labels
def round_batch(self, x: torch.tensor, lengths: torch.tensor):
"""
For float16 only.
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
Input:
------
x: `torch.tensor(bs, seq_length)` - The token ids.
lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch.
Output:
-------
x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths.
"""
if not self.fp16 or len(lengths) < 8:
return x, lengths
# number of sentences == 0 [8]
bs1 = len(lengths)
bs2 = 8 * (bs1 // 8)
assert bs2 > 0 and bs2 % 8 == 0
if bs1 != bs2:
idx = torch.randperm(bs1)[:bs2]
lengths = lengths[idx]
slen = lengths.max().item()
x = x[idx, :slen]
else:
idx = None
# sequence length == 0 [8]
ml1 = x.size(1)
if ml1 % 8 != 0:
pad = 8 - (ml1 % 8)
ml2 = ml1 + pad
if self.mlm:
pad_id = self.params.special_tok_ids["pad_token"]
else:
pad_id = self.params.special_tok_ids["unk_token"]
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
x = torch.cat([x, padding_tensor], 1)
assert x.size() == (bs2, ml2)
assert x.size(0) % 8 == 0
assert x.size(1) % 8 == 0
return x, lengths
def train(self):
"""
The real training loop.
"""
if self.is_master:
logger.info("Starting training")
self.last_log = time.time()
self.student.train()
self.teacher.eval()
for _ in range(self.params.n_epoch):
if self.is_master:
logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
if self.multi_gpu:
torch.distributed.barrier()
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
for batch in iter_bar:
if self.params.n_gpu > 0:
batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
if self.mlm:
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
else:
token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
iter_bar.update()
iter_bar.set_postfix(
{"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
)
iter_bar.close()
if self.is_master:
logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
self.end_epoch()
if self.is_master:
logger.info(f"Save very last checkpoint as `pytorch_model.bin`.")
self.save_checkpoint(checkpoint_name=f"pytorch_model.bin")
logger.info("Training is finished")
def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
"""
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
and possibly a parameter update (depending on the gradient accumulation).
Input:
------
input_ids: `torch.tensor(bs, seq_length)` - The token ids.
attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
"""
if self.mlm:
s_logits, s_hidden_states = self.student(
input_ids=input_ids, attention_mask=attention_mask
) # (bs, seq_length, voc_size)
with torch.no_grad():
t_logits, t_hidden_states = self.teacher(
input_ids=input_ids, attention_mask=attention_mask
) # (bs, seq_length, voc_size)
else:
s_logits, _, s_hidden_states = self.student(
input_ids=input_ids, attention_mask=None
) # (bs, seq_length, voc_size)
with torch.no_grad():
t_logits, _, t_hidden_states = self.teacher(
input_ids=input_ids, attention_mask=None
) # (bs, seq_length, voc_size)
assert s_logits.size() == t_logits.size()
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
if self.params.restrict_ce_to_mask:
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
else:
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = (
self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1),
)
* (self.temperature) ** 2
)
loss = self.alpha_ce * loss_ce
if self.alpha_mlm > 0.0:
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
loss += self.alpha_mlm * loss_mlm
if self.alpha_clm > 0.0:
shift_logits = s_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss += self.alpha_clm * loss_clm
if self.alpha_mse > 0.0:
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
0
) # Reproducing batchmean reduction
loss += self.alpha_mse * loss_mse
if self.alpha_cos > 0.0:
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
assert s_hidden_states.size() == t_hidden_states.size()
dim = s_hidden_states.size(-1)
s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
loss += self.alpha_cos * loss_cos
self.total_loss_epoch += loss.item()
self.last_loss = loss.item()
self.last_loss_ce = loss_ce.item()
if self.alpha_mlm > 0.0:
self.last_loss_mlm = loss_mlm.item()
if self.alpha_clm > 0.0:
self.last_loss_clm = loss_clm.item()
if self.alpha_mse > 0.0:
self.last_loss_mse = loss_mse.item()
if self.alpha_cos > 0.0:
self.last_loss_cos = loss_cos.item()
self.optimize(loss)
self.n_sequences_epoch += input_ids.size(0)
def optimize(self, loss):
"""
Normalization on the loss (gradient accumulation or distributed training), followed by
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
Also update the metrics for tensorboard.
"""
# Check for NaN
if (loss != loss).data.any():
logger.error("NaN detected")
exit()
if self.multi_gpu:
loss = loss.mean()
if self.params.gradient_accumulation_steps > 1:
loss = loss / self.params.gradient_accumulation_steps
if self.fp16:
from apex import amp
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
self.iter()
if self.n_iter % self.params.gradient_accumulation_steps == 0:
if self.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
def iter(self):
"""
Update global counts, write to tensorboard and save checkpoint.
"""
self.n_iter += 1
self.n_total_iter += 1
if self.n_total_iter % self.params.log_interval == 0:
self.log_tensorboard()
self.last_log = time.time()
if self.n_total_iter % self.params.checkpoint_interval == 0:
self.save_checkpoint()
def log_tensorboard(self):
"""
Log into tensorboard. Only by the master process.
"""
if not self.is_master:
return
for param_name, param in self.student.named_parameters():
self.tensorboard.add_scalar(
tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
)
if param.grad is None:
continue
self.tensorboard.add_scalar(
tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="losses/cum_avg_loss_epoch",
scalar_value=self.total_loss_epoch / self.n_iter,
global_step=self.n_total_iter,
)
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
self.tensorboard.add_scalar(
tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
)
if self.alpha_mlm > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
)
if self.alpha_clm > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
)
if self.alpha_mse > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
)
if self.alpha_cos > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="global/memory_usage",
scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
global_step=self.n_total_iter,
)
self.tensorboard.add_scalar(
tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
)
def end_epoch(self):
"""
Finally arrived at the end of epoch (full pass on dataset).
Do some tensorboard logging and checkpoint saving.
"""
logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
if self.is_master:
self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
self.tensorboard.add_scalar(
tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
)
self.epoch += 1
self.n_sequences_epoch = 0
self.n_iter = 0
self.total_loss_epoch = 0
def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
"""
Save the current state. Only by the master process.
"""
if not self.is_master:
return
mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
mdl_to_save.config.save_pretrained(self.dump_path)
state_dict = mdl_to_save.state_dict()
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))

[코드 line 53 - 54]

Teacher, Student 모델 넘겨받기

 

[코드 line 81 - 94]

Masked language modeling loss인 $L_{mlm}$을 위한 과정

 

[코드 line 111 - 116]

Loss를 계산해줄 객체 생성, 아래에서 논문에서 나온 3개의 Loss를 어떻게 코딩했는지 살펴봅시다.

 

[코드 line 324 - 365]

[331 - 332] Student 모델은 train 상태Teacher 모델은 eval 상태로(freeze)

[334 - 365] Epoch 수만큼 학습시킵니다.

 

[코드 line 366 - 459]

loss_ce, loss_mlm, loss_cos를 제외한 나머지 loss와 alpha는 GPT2, roberta 모델일 때 필요하기 때문에 넘어가겠습니다.

[406 - 412] $L_{CE}$ 계산, $t_{i}$와 $s_{i}$를 먼저 구한 뒤 softmax with temperature를 사용하여 계산합니다.

[415 - 417] $L_{mlm}$ 계산

[429 - 443] Teacher와 Student 모델의 마지막 hidden state의 vector 방향을 align하는 목적으로 사용되기 때문에 hidden_states의 마지막 벡터 s_hidden_states, t_hidden_states로 Loss를 계산합니다.

 

 

 

모델을 직접 학습해보지는 않았지만 코드를 이용해서 논문에서의 내용이 어떻게 구현되어있는지 살펴보았습니다. DistilBERT 논문을 읽고 코드를 보며 느낀 점은 학습 방식이 Distilling the Knowledge in a Neural Network과 크게 달라지지 않았지만 논문의 목표는 Knowledge Distilation을 사용하여 BERT의 사이즈를 줄이고 비슷한 퍼포먼스를 유지하는 것이기 때문에 BERT에 대해 잘 모른다면 DistilBERT의 아키텍처와 코드를 이해하기 어렵다는 것입니다. 또한 Knowledge Distilation의 아이디어를 간단하게 어떤 모델이든 적용해볼 수 있겠다는 생각이 들었습니다.

 

 

모두의 연구소에서 진행하는
"함께 콘텐츠를 제작하는 콘텐츠 크리에이터 모임"
COCRE(코크리) 2기 회원으로 제작한 글입니다. 
코크리란? 🐘

 

Reference

[그림 1] https://upload.wikimedia.org/wikipedia/commons/a/a0/Milky_Way_libya.jpg

[그림 2] DistilBERT 논문

[그림 3] https://intellabs.github.io/distiller/knowledge_distillation.html

[그림 4] DistilBERT 논문