Self Supervised Learning를 여행하는 히치하이커를 위한 안내서
저는 요즘 모두의 연구소에서 Self Supervised Learning(SSL)을 공부하고 SSL에 대해 논문을 쓰는 것을 목표로 하는 SSL LAB(쓸랩)의 연구원으로도 활동하고 있습니다. 제가 정말 논문을.. 쓸 수 있을까요? 🥲 공부하게될 논문들이 SSL을 공부하는 좋은 로드맵이 될 것 같아 살짝 정리를 해보는 것이 좋을 것이라 생각했습니다. 논문 설명과 논문에서 공헌한 부분의 코드를 설명하는 시리즈가 될 것 같습니다.
미리 알아두면 좋은 지식
- MNIST
: 손으로 쓴 숫자들로 이루어진 28 x 28 픽셀 사이즈로 구성된 이미지 데이터.
- Ensemble
: 같은 데이터로 여러개의 분류기 모델을 학습시켜 모델의 예측을 결합하여 더 좋은 성능을 도출하는 기법. Bagging, Boosting, Stacking 등의 기법이 있다.
- Softmax
: 뉴런의 출력값에 대하여 분류를 위하여 마지막 단계에서 출력값을 0~1사이의 값으로 정규화 해주는 함수.
$q_i = \frac{exp(z_i)}{\sum_jexp(z_j)}$로 표현한다.
- Kullback–Leibler divergence
: 두 확률분포의 차이를 계산하는 데에 사용하는 함수.
Distilling the Knowledge in a Neural network
ML 알고리즘의 성능을 올리는 간단한 방법은 여러 개의 모델을 같은 데이터로 학습시키고, 그 모델들이 예측을 평균내는 것입니다. 하지만 이 방법은 성가시고(cumbersome) 비용이 많이 듭니다(too computationally expensive). 모델이 간단해지면 inference 속도도 빨라지고 용량도 작아지는 장점이 있습니다. 그래서 Hinton et al.은 작은 한 개의 모델로 더 큰 모델 만큼의 성능을 낼 수 있는 방법인 Knowledge Distillation 개념을 소개하고 뉴럴넷에 적용하였습니다. 그 방법은, 잘 학습된 보다 복잡하고 큰 Teacher 모델이 간단하고 작은 Student 모델에게 generalization한 정보를 학습하게 하여 Teacher와 비슷한 성능을 내도록 하는 것입니다.
Architecture
우선 Hinton et al.이 제안한 Knowledge Distillation 방법에 대한 아키텍처를 보는 것이 이해가 빠를 것 같습니다.
순서는 다음과 같습니다.
1. Teacher model 학습
2. Student model 학습 (Teacher가 Student에게 knowledge Distillation)
- Soft Targets, Soft Labels
- Distillation loss
- Student Loss
맨 처음으로 Teacher model을 학습합니다. Teacher model은 Student model 보다 더 큰 파라미터를 가지고 있습니다. 그리고 Student model을 학습하기 위해 학습한 Teacher model을 freeze 시킵니다. 위의 [그림 2] 아키텍처에서 Softmax(T = t)와 Softmax(T = 1)이 보입니다. T = t 일 때의 의미는 $q_i = \frac{exp(z_i/T)}{\sum_jexp(z_j/T)}$와 같습니다. 왜 이렇게 했을까요?
Soft Targets, Soft Labels
T = 1일 때는 기존에 우리가 알던 수식 $q_i = \frac{exp(z_i)}{\sum_jexp(z_j)}$과 같을 것입니다. 하지만 T = t 일 때는 다릅니다. 왜 t로 나누도록 했을까요?
[그림 3]은 자율주행 자동차 카메라의 object classification 결과(softmax 값) 입니다. softmax 함수는 각 클래스에 대한 확률을 보여주고, 자동차가 0.78, 사람이 0.003, 트럭이 0.215, 신호등이 0.002이라는 확률로 카메라가 자동차를 인식하고 있다는 것을 알 수 있습니다.
이 물체가 확실하게 자동차라는 정보도 중요하지만 이 물체가 사람이나 신호등 보다 트럭과 닮았다는 사실도 중요합니다. 이러한 정보를 student model에게 가르치기 위해 softmax를 계산할 때 T(Temperature)라는 hyperparameter를 이용해 softning 한 것 입니다. 아래의 [그림 4]는 T = 5 일 때인 값이고 soft targets(soft labels)이라고 부릅니다.
Hyperparameter T는 실험을 통해 알맞는 값을 결정합니다.
Distillation Loss
Loss는 총 두 개가 있습니다. Distillation Loss와 Student Loss 입니다. Distillation loss는 말그대로 Teacher model이 가지고 있는 Knowledge를 Student model에게 Distillation을 하기 위한 Loss 입니다. Tempearture hyperparameter로 softning한 각각의 Teacher와 Student의 soft target을 KL divergence Loss를 이용하여 계산합니다. KL divergence로 두 확률분포의 차이를 계산할 수 있기 때문에 Distillation Loss에서 KL divergence를 사용함으로서 Teacher와 Student의 확률분포가 비슷해지도록 학습을 유도합니다.
Student Loss
Student Loss를 계산하기 위해 Cross Entropy Loss를 사용하고 soft targets이 아닌 T = 1인 Hard targets을 사용합니다. 이로서 Student model는 Distillation Loss를 통해 Teacher model의 knowledge를 전수받고, Student Loss를 통해 더 정확한 값을 학습할 수 있습니다.
Knowledge Distillation Code
Knowledge Distillation을 구현해놓은 코드를 보며 중요한 부분들을 이야기해봅시다. 이곳의 코드를 참고했습니다. 다음과 같은 순서로 코드를 살펴볼 것입니다.
1. Teacher model, Student model 만들기
2. Teacher model 학습하기
3. Teacher model Freeze, Student model 학습하기
1. Teacher model, Student model 만들기
MNIST 분류 모델 예시이기 때문에 input size는 28 x 28 입니다. 위 코드를 보고 Teacher model은 Student model 보다 파라미터가 더 많다는 사실을 반영했다는 것을 알 수 있습니다.
2. Teacher model 학습하기
Epoch 1/5
1875/1875 [==============================] - 248s 132ms/step - loss: 0.2438 - sparse_categorical_accuracy: 0.9220
Epoch 2/5
1875/1875 [==============================] - 263s 140ms/step - loss: 0.0881 - sparse_categorical_accuracy: 0.9738
Epoch 3/5
1875/1875 [==============================] - 245s 131ms/step - loss: 0.0650 - sparse_categorical_accuracy: 0.9811
Epoch 5/5
363/1875 [====>.........................] - ETA: 3:18 - loss: 0.0555 - sparse_categorical_accuracy: 0.9839
데이터를 불러오고 Student model을 학습하기 전 먼저 Teacher model을 학습합니다. 결과에서 5 에폭을 학습시켜 0.9839 정확도가 나온 것을 볼 수 있습니다.
3. Teacher model Freeze, Student model 학습하기
이제 knowledge distillation을 할 차례입니다. 우리가 위에서 배운 [그림 2] 대로 아키텍처를 구성합니다. 이 Distiller 객체를 생성할 때 Teacher model과 Student model이 각각 매개변수로 들어가게 합니다. 학습 단계에서 어떻게 했는지 다시 생각해보면서, train_step 함수를 보면서 따라가보세요.
[코드 line 35 -52]
두 개의 Loss 함수인 Distillation Loss, Student Loss가 있습니다. Distillation Loss는 Teacher model로부터 얻은 teacher_predictions과 Student model로 부터 얻은 student_predictions을 hyperparameter T(Temperature)를 이용해 softmax 값을 구하여 Distillation Loss를 계산합니다. 이 때 Teacher model은 학습을 하지 않기 때문에 training=False로 설정해주어야 합니다. Student Loss는 student_predictions와 ground truth 값인 y를 이용하여 계산합니다. 그리고 두 loss를 합칩니다.
[코드 line 54 -62]
Gradient 계산 부분
Epoch 1/3
1875/1875 [==============================] - 242s 129ms/step - sparse_categorical_accuracy: 0.9761 - student_loss: 0.1526 - distillation_loss: 0.0226
Epoch 2/3
1875/1875 [==============================] - 281s 150ms/step - sparse_categorical_accuracy: 0.9863 - student_loss: 0.1384 - distillation_loss: 0.0185
Epoch 3/3
399/1875 [=====>........................] - ETA: 3:27 - sparse_categorical_accuracy: 0.9896 - student_loss: 0.1300 - distillation_loss: 0.0182
distiller를 compile 할 때 Distillation Loss로 KL Divergence를 Student Loss로 Cross Entropy Loss를 사용하는 것을 볼 수 있습니다. alpha는 Distillation Loss와 Student Loss의 가중치 값으로 각가의 가중치는 합이 1이 됩니다.
Knowledge Distillation 한 결과 Student model이 3에폭만에 0.9896 정확도로 Teacher model 보다 더 높은 정확도를 얻은 것을 확인할 수 있습니다! 그럼 만약 Student model과 똑같이 생긴 모델을 Knowledge Distillation 하지않고 스크래치부터 학습시키면 어떻게 될까요?
Epoch 1/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4731 - sparse_categorical_accuracy: 0.8550
Epoch 2/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0966 - sparse_categorical_accuracy: 0.9710
Epoch 3/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0750 - sparse_categorical_accuracy: 0.9773
313/313 [==============================] - 0s 963us/step - loss: 0.0691 - sparse_categorical_accuracy: 0.9778
[0.06905383616685867, 0.9778000116348267]
student_scratch 모델은 student가 knowledge distillation으로 학습하기 전 아래의 코드로 미리 똑같이 복사를 해두었습니다.
스크래치부터 학습한 모델은 0.9778 정확도를 얻었습니다. Knowledge distillation한 결과가 0.9896로 가장 좋은 결과를 얻었습니다.
모두의 연구소에서 진행하는
"함께 콘텐츠를 제작하는 콘텐츠 크리에이터 모임"
COCRE(코크리) 2기 회원으로 제작한 글입니다.
코크리란? 🐘
Reference
[그림 1] https://upload.wikimedia.org/wikipedia/commons/a/a0/Milky_Way_libya.jpg
[그림 2] https://intellabs.github.io/distiller/knowledge_distillation.html
[SSL LAB 자료] https://github.com/modu-ssl-lab/ssl-papers/issues/1