Engineering Whitepaper

Google TPU 실전 가이드(Part 2): JAX와 PyTorch XLA 마스터하기

GPU에 익숙한 당신을 위한 TPU 마이그레이션 가이드. XLA의 마법으로 연산 속도를 3배 높이는 비법을 공개합니다.

By MUMULAB Engineering Team 2025년 11월 26일 25 min read
Coding on TPU
TPU는 더 이상 '구글만의 장난감'이 아닙니다. 모든 엔지니어가 다뤄야 할 필수 장비입니다.
"Part 1에서 TPU의 개념을 이해했다면, Part 2에서는 직접 코드를 작성해 봅니다. 걱정하지 마세요. CUDA 커널을 직접 짤 필요는 없습니다. 우리에겐 JAX와 XLA가 있으니까요."

서론: 왜 코드를 바꿔야 하는가?

많은 개발자들이 "TPU를 쓰려면 텐서플로우(TensorFlow)를 다시 배워야 하나요?"라고 묻습니다. 2025년의 대답은 "아니요"입니다. PyTorch XLA의 성숙과 JAX의 대중화로 인해, 기존 GPU 코드를 거의 그대로 TPU에서 실행할 수 있습니다. 하지만 TPU의 성능을 100% 끌어내기 위해서는 그 뒤단에서 작동하는 XLA (Accelerated Linear Algebra) 컴파일러의 원리를 이해해야 합니다.


Part 1. PyTorch 유저를 위한 가이드 (PyTorch/XLA)

1. XLA란 무엇인가?

GPU는 명령을 하나씩 즉시 실행(Eager Execution)하는 데 능숙하지만, TPU는 전체 연산 그래프를 보고 최적화한 뒤 한 번에 실행(Graph Execution)할 때 가장 빠릅니다. XLA는 당신의 PyTorch 코드를 분석하여 TPU가 이해할 수 있는 최적의 기계어(HLO)로 번역해 주는 '통역사'입니다.

2. 코드 마이그레이션: 3줄의 마법

기존 PyTorch 코드에 단 몇 줄만 추가하면 TPU를 사용할 수 있습니다.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# 1. 디바이스 설정 (cuda 대신 xla_device 사용)
device = xm.xla_device()

# 2. 모델과 데이터를 TPU로 이동
model = MyNeuralNet().to(device)
data = data.to(device)

# 3. 학습 루프 내에서 Optimizer Step 변경
# optimizer.step() 대신 xm.optimizer_step(optimizer) 사용
# 이는 모든 그라디언트 계산이 끝날 때까지 기다렸다가(Lazy) 한 번에 업데이트하기 위함입니다.
def train_loop():
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)  # 핵심 변경 사항!

3. 성능 최적화 팁

  • 고정된 입력 크기 (Static Shape): TPU는 입력 데이터의 크기가 변할 때마다 컴파일을 다시 해야 합니다(Re-compilation). 이는 엄청난 오버헤드를 유발하므로, 배치 사이즈를 고정하고 남는 공간은 패딩(Padding)으로 채우는 것이 좋습니다.
  • Too Many Tensors 에러: XLA 그래프가 너무 커지면 메모리 부족이 발생할 수 있습니다. xm.mark_step()을 적절히 호출하여 그래프를 잘라주세요.

Part 2. 고수들의 선택: JAX (Just After eXecution)

JAX는 구글이 작정하고 만든 'Numpy의 GPU/TPU 가속 버전'입니다. 딥마인드의 알파포드(AlphaFold)나 제미나이(Gemini) 모델도 JAX로 만들어졌습니다.

1. JAX의 핵심 철학: 함수형 프로그래밍

PyTorch가 객체 지향적(Object-Oriented)이라면, JAX는 순수 함수형(Pure Functional)입니다. 상태(State)를 저장하지 않고, 입력이 같으면 항상 출력이 같아야 합니다. 이는 병렬 처리에 엄청난 이점을 줍니다.

2. jit, grad, vmap, pmap

JAX의 4대 천왕입니다.

  • @jit: 함수를 JIT 컴파일하여 XLA로 최적화합니다. 속도가 10배 이상 빨라집니다.
  • grad: 미분 함수를 자동으로 만들어줍니다. 역전파(Backpropagation)를 짤 필요가 없습니다.
  • vmap: for 루프 없이 벡터화 연산(Vectorization)을 수행합니다.
  • pmap: 여러 개의 TPU 코어에 자동으로 작업을 분산(Parallelization)시킵니다.
import jax
import jax.numpy as jnp

# 단순한 행렬 곱 함수
def raw_func(x, w):
    return jnp.dot(x, w)

# 1. JIT 컴파일 (속도 향상)
fast_func = jax.jit(raw_func)

# 2. 자동 미분 (기울기 계산)
grad_func = jax.grad(raw_func)

# 3. 병렬 처리 (8개 TPU 코어 동시 사용)
# 입력 데이터를 8개로 쪼개서 각 코어에 던져줍니다.
parallel_func = jax.pmap(fast_func)

Part 3. 실전 예제: TPU Pod에서 LLM 학습하기

1. 데이터 파이프라인 구축 (tf.data)

TPU는 연산 속도가 너무 빨라서, CPU가 데이터를 전처리해서 넘겨주는 속도가 병목(Bottleneck)이 될 수 있습니다. tf.data API를 사용하여 데이터를 미리 로드(Prefetch)하고 캐싱(Cache)해야 합니다.

2. 분산 학습 전략 (Data Parallelism vs Model Parallelism)

수십억 개의 파라미터를 가진 LLM은 TPU 한 대에 들어가지 않습니다.

  • FSDP (Fully Sharded Data Parallel): 모델의 파라미터, 그라디언트, 옵티마이저 상태를 모든 TPU 코어에 잘게 쪼개서 보관합니다. 메모리 효율을 극대화할 수 있습니다.
  • GSPMD (General Sharding): JAX만의 강력한 기능으로, 텐서를 논리적으로 쪼개는 규칙만 정해주면 컴파일러가 알아서 분산 처리를 수행합니다.

결론: 도구에 종속되지 않는 엔지니어가 되라

TPU와 JAX는 처음엔 낯설 수 있지만, 한 번 익숙해지면 GPU에서는 상상할 수 없었던 속도와 확장성을 경험하게 됩니다. 2026년의 AI 엔지니어링은 "누가 더 모델을 잘 만드냐"가 아니라 "누가 더 하드웨어 자원을 효율적으로 쓰느냐"의 싸움이 될 것입니다. 지금 당장 Google Colab을 켜고 TPU 런타임을 선택해 보세요. 새로운 세상이 열릴 것입니다.


자주 묻는 질문 (FAQ)

Q1: Colab 무료 버전에서도 TPU를 쓸 수 있나요?

네, Google Colab은 기본적으로 TPU v2-8을 무료로 제공합니다. 학습 목적으로 사용하기에 충분하며, 프로 버전을 구독하면 더 고성능의 TPU를 사용할 수 있습니다.

Q2: 디버깅이 너무 어렵지 않나요?

XLA 특성상(Lazy Execution) 에러가 발생한 정확한 지점을 찾기 어려울 때가 있습니다. 이럴 때는 JAX_DISABLE_JIT=1 환경 변수를 설정하여 강제로 한 줄씩 실행하게 만들면 디버깅이 쉬워집니다.

Q3: 윈도우(Windows)에서도 JAX를 쓸 수 있나요?

2024년까지는 공식 지원이 미흡했지만, 2025년부터 윈도우에 대한 실험적 지원이 강화되었습니다. 하지만 여전히 리눅스(Linux)나 WSL2 환경에서 사용하는 것이 정신 건강에 좋습니다.