본문 바로가기
DL|ML

Google I/O 2023과 Keras Core

by 이든Eden 2023. 7. 15.

Google I/O 2023 이후 그리고 며칠 전 Keras가 Keras Core가 발표한 이후 Keras는 날이 갈 수록 더 핫해져가고 있습니다 🔥🔥🔥🔥🔥

 

그래서 두 가지 소식을 follow 하려고 합니다! 하나는 Google I/O에서의 내용, 하나는 Keras Core에 대한 내용입니다. 두 글에서 케라스와 관련된 내용을 짧고 굵게 읽을 수 있도록 압축했습니다.

 

  • Google I/O 2023: What’s new in TensorFlow and Keras?
  • Introducing Keras Core: Keras for TensorFlow, JAX, and PyTorch

 

 

Google I/O 2023: What’s new in TensorFlow and Keras?

 

KerasCV와 KerasNLP

 

강력하고 모듈화된 KerasCV, KerasNLP로 다양한 테스크의 최신 모델을 코드 몇 줄로 사용할 수 있습니다.

 

The KerasCV + KerasNLP suite, at a glance.

 

 

Diffusion Model을 가져오는 코드를 보여드리겠습니다.

 

# keras_cv import
from keras_cv.models import (
    StableDiffusion,
)
  
# 모델 가져오기
model = StableDiffusion(
    img_width=512,
    img_height=512,
)

# 인퍼런스 하기
images = model.text_to_image(
    "photograph of an astronaut "
    "riding a horse",
    batch_size=3,
)
 

단 몇 줄만의 코드로 이미지를 생성할 수 있습니다.

KerasCV-generated images of an astronaut riding a horse!

 

 

DTensor

파라미터가 거대한 모델을 사용할 때 GPU 메모리 부족이 나는 일은 자주 있었을 것입니다. 이러한 문제를 해결하기 위해 DTensor가 필요합니다!

 

DTensor는 모델 병렬화를 통해 모델을 여러 하드웨어 장치에 걸쳐 확장할 수 있게 합니다. 또한, DTensor는 데이터와 모델 병렬화를 조합하여 효율적으로 모델을 확장할 수 있으며, TPU, GPU 등 어떤 가속기를 사용하더라도 호환됩니다.

 

Mixed (data + model) parallelism, with DTensor.

 

DTensor는 모델 변경 없이 코드 몇 줄만으로 사용가능하며 tf.distribute와 Keras와 같은 주요 인터페이스와 완전히 통합될 것 입니다.

 

아래는 DTensor를 사용하여 1750억개의 파라미터를 갖는 Open Pre-trained Transformer(OPT)를 학습하는 예시입니다.

mesh_dims = [("batch", 2), ("model", 4)]
mesh = dtensor.create_distributed_mesh(mesh_dims, device_type="GPU")
dtensor.initialize_accelerator_system("GPU")

layout_map = keras_nlp.models.OPTCausalLM.create_layout_map(mesh)

with layout_map.scope():
    opt_lm = keras_nlp.models.OPTCasualLM.from_preset("opt_6.7b_en")
opt_lm.compile(...)
opt_lm.fit(wiki_text_dataset)

 

 

 

 

Introducing Keras Core: Keras for TensorFlow, JAX, and PyTorch

 

2023년 가을 Keras Core는 Keras 3.0으로 발전할 것입니다. Keras Core는 Keras 코드 기반을 완전히 재작성하여 모듈화된 백엔드 아키텍처 위에 다시 구축한 것입니다. 이를 통해 임의의 프레임워크 위에서 Keras 워크플로우를 실행할 수 있게 됩니다. 초지 지원 프레임워크로는 TensorFlow, JAX, PyTorch가 될 것입니다.

 

TensorFlow를 사용할 때 tf.keras와 거의 완벽하게 호환성을 제공할 것이며 대부분의 경우, 기존 코드를 from tensorflow import keras 대신 import keras_core as keras로 가져와서 문제없이 실행할 수 있습니다. 더불어 XLA 컴파일로 인해 약간의 성능 향상도 기대할 수 있습니다.

 

초기 지원에 포함된 세개의 프레임워크는 각각 장단점이 있습니다. TensorFlow는 55-60%의 시장 점유율을 가지고 있고 프로덕션 레벨에 아주 유용하고, PyTorch는 45%-50%의 점유율을 가지며 ML 연구에서 많이 쓰이고 있습니다. JAX는 점유율은 낮지만 Generative AI 분야에서 많이 사용되고 있습니다. 이 각 프레임워크는 다양한 사용 사례에 중요한 역할을 하고 있으며, Keras Core를 통해 이 모든 프레임워크를 사용할 수 있게 해주는 것이 Keras multi-backend를 다시 만든 이유입니다.

 

 

케라스 코어의 주요 기능

 

미리 보기 릴리스에 포함된 몇 가지 사항을 살펴보겠습니다.

 

Keras Core는 TensorFlow, JAX, PyTorch와 함께 전체 Keras API를 구현하고 사용할 수 있도록 합니다. 이는 수백 개의 레이어, 수십 개의 메트릭, 손실 함수, 옵티마이저, 콜백, Keras의 학습 및 평가 루프, 그리고 Keras의 저장 및 직렬화 인프라를 포함합니다.

 

기존 tf.keras 모델은 keras import를 keras_core로 변경하기만 하면 JAX와 PyTorch에서 즉시 실행할 수 있습니다! 

 

 

Keras Core는 어떤 프레임워크에서든 동일하게 작동하는 구성 요소(사용자 정의 레이어 또는 pretrained model)를 생성할 수 있도록 해주는 교차 프레임워크 저수준 언어입니다. 

 

keras_core.ops를 사용해서 NumPy API처럼 사용할 수 있게 됩니다. Numpy와 동일한 함수와 동일한 인자를 갖추고 있고, ops.matmul, ops.sum, ops.stack, ops.einsum 등과 같은 함수를 사용할 수 있습니다. 그리고 ops.softmax, ops.binary_crossentropy, ops.conv와 같이 NumPy에는 없지만 신경망에 특화된 함수도 제공합니다.

 

keras_core.ops를 사용한다면, 사용자 정의 레이어, 사용자 정의 손실, 사용자 정의 메트릭, 사용자 정의 옵티마이저는 JAX, PyTorch, TensorFlow와 함께 동작할 수 있습니다. 동일한 코드로 모든 프레임워크에서 사용할 수 있으며, 이는 하나의 구성 요소(예: model.py와 단일 체크포인트 파일)를 유지하면 모든 프레임워크에서 일관된 수치 연산을 사용할 수 있습니다.

 

 

즉, keras_core.ops를 사용해 numpy 함수와 동일한 인자로 같은 결과를 받을 수 있기 때문에, 모든 프레임워크에서 일관된 수치 연산을 사용할 수 있습니다! 케라스 세상!

 

 

 

JAX, PyTorch 및 TensorFlow의 기본 워크플로와 원활하게 통합 

 

이전 multi-backend Keras 1.0과 달리 Keras Core는 Keras 중심적인 워크플로우뿐만 아니라 낮은 수준의 백엔드 기반 워크플로우와 원활하게 연동될 수 있도록 설계되었습니다.

 

Keras Core는 JAX와 PyTorch에서 tf.keras가 이전에 TensorFlow에서 제공한 것과 동일한 수준의 저수준 구현 유연성을 제공합니다.

 

  • optax 옵티마이저, jax.grad, jax.jit, jax.pmap를 사용하여 JAX 훈련 루프를 작성하여 Keras 모델을 학습할 수 있습니다.
  • tf.GradientTape 및 tf.distribute를 사용하여 TensorFlow의 저수준 훈련 루프를 작성하여 Keras 모델을 학습할 수 있습니다.
  • torch.optim 옵티마이저, torch 손실 함수, torch.nn.parallel.DistributedDataParallel 래퍼를 사용하여 PyTorch의 저수준 학습 루프를 작성하여 Keras 모델을 학습할 수 있습니다.
  • Keras 레이어 또는 모델을 torch.nn.Module의 일부로 사용할 수 있습니다. 이는 PyTorch 사용자가 Keras API를 사용하든 말든 Keras 모델을 활용할 수 있음을 의미합니다. Keras 모델을 다른 PyTorch Module과 동일하게 취급할 수 있습니다.

 

즉, 초기 지원 프레임워크 어디서든 케라스 모델을 학습할 수 있습니다! 케라스 세상!

 

 

 

Multi framework ML은 multi framework 데이터 로딩 및 전처리를 의미합니다. Keras Core 모델은 JAX, PyTorch 또는 TensorFlow 백엔드를 사용하더라도 다양한 데이터 파이프라인을 사용하여 학습할 수 있습니다.

  • tf.data.Dataset 파이프라인
  • torch.utils.data.DataLoader 객체
  • NumPy 배열 및 Pandas 데이터프레임
  • keras_core.utils.PyDataset 객체

 

즉, 어떤 데이터 로더 및 전처리 코드를 사용해도 상관없습니다! 케라스 세상!

 

 

 

KerasCV와 KerasNLP의 다양한 pretrained model 모델(BERT, T5, YOLOv8, Whisper 등)을  모든 백엔드에서 사용할 수 있습니다.

 

즉, Keras Core의 KerasCV, KerasNLP에서 최신의 pretrained model을 가져와 모든 백엔드에서 학습할 수 있습니다! 케라스 세상!

 

 

Keras는 모델을 구축하고 훈련하는 단일 "진정한" 방법을 강요하지 않습니다. 대신, 다양한 워크플로우를 지원하여 다른 사용자 프로필에 해당하는 매우 고수준에서 매우 저수준까지 다양한 작업 흐름을 제공합니다. 이는 간단한 워크플로우(Sequential 및 Functional 모델 사용 및 fit()으로 훈련)부터 시작하여 필요에 따라 다양한 구성 요소를 쉽게 사용자 정의하고 이전 코드의 대부분을 재사용할 수 있는 유연성을 제공합니다.

 

특정한 요구사항이 더욱 구체화 됐을 때 다른 툴로 전환할 필요가 없습니다. 이 원칙을 모든 백엔드에 적용했습니다. 예를 들어, train_step 메서드를 오버라이드함으로써 fit()의 강력함을 활용하면서도 학습 루프에서 발생하는 동작을 사용자 정의할 수 있습니다. PyTorch와 TensorFlow에서는 다음과 같이 작동합니다:

 

 

JAX 버전은 여기

 

즉, Keras는 다양한 작업 흐름과 유연성을 제공하도록 모델 구축과 학습을 사용자가 정의할 수 있고 호환이 가능하게 만듭니다. 케라스 세상!

 

 

 

순수 함수형 프로그래밍을 좋아하시나요? 기쁜 소식이 있습니다. Keras의 모든 stateful 객체(즉, 훈련 또는 평가 중에 업데이트되는 숫자 변수를 소유하는 객체)는 이제 stateless API를 갖추어 JAX 함수에서 사용할 수 있게 되었습니다

 

  • 모든 레이어와 모델은 call()과 동일한 기능을 하는 stateless_call() 메서드를 갖추었습니다.
  • 모든 옵티마이저는 apply()과 동일한 기능을 하는 stateless_apply() 메서드를 갖추었습니다.
  • 모든 메트릭은 update_state()을 반영하는 stateless_update_state() 메서드와 result()를 반영하는 stateless_result() 메서드를 갖추었습니다.

 

이러한 메서드들은 어떠한 side-effects도 없습니다. 즉, 입력으로 현재 상태 변수의 값을 받고, 업데이트된 값들을 반환합니다. 예를 들면: outputs, updated_non_trainable_variables = layer.stateless_call( trainable_variables, non_trainable_variables, inputs, ) 이러한 메서드들은 직접 구현할 필요가 없습니다. Stateful하게 (call() 또는 update_state() 등)을 구현한 경우 자동으로 사용할 수 있습니다.

 

 

즉, Keras가 함수형 프로그래밍을 따르는 stateless한 형태를 갖춘 함수들을 이용해 JAX에서 사용할 수 있게 되었습니다! 케라스 세상!

 

 

 

다만, Keras Core는 아직 베타버전이며, 문제점이 발생할 수 있다는 점을 염두에 두셔야 합니다. 주의해야할 점과 Keras Core에 대해 자주 물어보는 질문들에 대한 질의응답을 읽고 싶다면 원문을 읽어보세요! 이 부분에 대해서도 요약을 원하시면 댓글을 달아주세요!