Google TPU 실전 가이드(Part 2): JAX와 PyTorch XLA 마스터하기
GPU에 익숙한 당신을 위한 TPU 마이그레이션 가이드. XLA의 마법으로 연산 속도를 3배 높이는 비법을 공개합니다.
서론: 왜 코드를 바꿔야 하는가?
많은 개발자들이 "TPU를 쓰려면 텐서플로우(TensorFlow)를 다시 배워야 하나요?"라고 묻습니다. 2025년의 대답은 "아니요"입니다. PyTorch XLA의 성숙과 JAX의 대중화로 인해, 기존 GPU 코드를 거의 그대로 TPU에서 실행할 수 있습니다. 하지만 TPU의 성능을 100% 끌어내기 위해서는 그 뒤단에서 작동하는 XLA (Accelerated Linear Algebra) 컴파일러의 원리를 이해해야 합니다.
Part 1. PyTorch 유저를 위한 가이드 (PyTorch/XLA)
1. XLA란 무엇인가?
GPU는 명령을 하나씩 즉시 실행(Eager Execution)하는 데 능숙하지만, TPU는 전체 연산 그래프를 보고 최적화한 뒤 한 번에 실행(Graph Execution)할 때 가장 빠릅니다. XLA는 당신의 PyTorch 코드를 분석하여 TPU가 이해할 수 있는 최적의 기계어(HLO)로 번역해 주는 '통역사'입니다.
2. 코드 마이그레이션: 3줄의 마법
기존 PyTorch 코드에 단 몇 줄만 추가하면 TPU를 사용할 수 있습니다.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
# 1. 디바이스 설정 (cuda 대신 xla_device 사용)
device = xm.xla_device()
# 2. 모델과 데이터를 TPU로 이동
model = MyNeuralNet().to(device)
data = data.to(device)
# 3. 학습 루프 내에서 Optimizer Step 변경
# optimizer.step() 대신 xm.optimizer_step(optimizer) 사용
# 이는 모든 그라디언트 계산이 끝날 때까지 기다렸다가(Lazy) 한 번에 업데이트하기 위함입니다.
def train_loop():
output = model(data)
loss = criterion(output, target)
loss.backward()
xm.optimizer_step(optimizer) # 핵심 변경 사항!
3. 성능 최적화 팁
- 고정된 입력 크기 (Static Shape): TPU는 입력 데이터의 크기가 변할 때마다 컴파일을 다시 해야 합니다(Re-compilation). 이는 엄청난 오버헤드를 유발하므로, 배치 사이즈를 고정하고 남는 공간은 패딩(Padding)으로 채우는 것이 좋습니다.
- Too Many Tensors 에러: XLA 그래프가 너무 커지면 메모리 부족이 발생할 수 있습니다.
xm.mark_step()을 적절히 호출하여 그래프를 잘라주세요.
Part 2. 고수들의 선택: JAX (Just After eXecution)
JAX는 구글이 작정하고 만든 'Numpy의 GPU/TPU 가속 버전'입니다. 딥마인드의 알파포드(AlphaFold)나 제미나이(Gemini) 모델도 JAX로 만들어졌습니다.
1. JAX의 핵심 철학: 함수형 프로그래밍
PyTorch가 객체 지향적(Object-Oriented)이라면, JAX는 순수 함수형(Pure Functional)입니다. 상태(State)를 저장하지 않고, 입력이 같으면 항상 출력이 같아야 합니다. 이는 병렬 처리에 엄청난 이점을 줍니다.
2. jit, grad, vmap, pmap
JAX의 4대 천왕입니다.
@jit: 함수를 JIT 컴파일하여 XLA로 최적화합니다. 속도가 10배 이상 빨라집니다.grad: 미분 함수를 자동으로 만들어줍니다. 역전파(Backpropagation)를 짤 필요가 없습니다.vmap: for 루프 없이 벡터화 연산(Vectorization)을 수행합니다.pmap: 여러 개의 TPU 코어에 자동으로 작업을 분산(Parallelization)시킵니다.
import jax
import jax.numpy as jnp
# 단순한 행렬 곱 함수
def raw_func(x, w):
return jnp.dot(x, w)
# 1. JIT 컴파일 (속도 향상)
fast_func = jax.jit(raw_func)
# 2. 자동 미분 (기울기 계산)
grad_func = jax.grad(raw_func)
# 3. 병렬 처리 (8개 TPU 코어 동시 사용)
# 입력 데이터를 8개로 쪼개서 각 코어에 던져줍니다.
parallel_func = jax.pmap(fast_func)
Part 3. 실전 예제: TPU Pod에서 LLM 학습하기
1. 데이터 파이프라인 구축 (tf.data)
TPU는 연산 속도가 너무 빨라서, CPU가 데이터를 전처리해서 넘겨주는 속도가 병목(Bottleneck)이 될 수 있습니다. tf.data API를 사용하여
데이터를 미리 로드(Prefetch)하고 캐싱(Cache)해야 합니다.
2. 분산 학습 전략 (Data Parallelism vs Model Parallelism)
수십억 개의 파라미터를 가진 LLM은 TPU 한 대에 들어가지 않습니다.
- FSDP (Fully Sharded Data Parallel): 모델의 파라미터, 그라디언트, 옵티마이저 상태를 모든 TPU 코어에 잘게 쪼개서 보관합니다. 메모리 효율을 극대화할 수 있습니다.
- GSPMD (General Sharding): JAX만의 강력한 기능으로, 텐서를 논리적으로 쪼개는 규칙만 정해주면 컴파일러가 알아서 분산 처리를 수행합니다.
결론: 도구에 종속되지 않는 엔지니어가 되라
TPU와 JAX는 처음엔 낯설 수 있지만, 한 번 익숙해지면 GPU에서는 상상할 수 없었던 속도와 확장성을 경험하게 됩니다. 2026년의 AI 엔지니어링은 "누가 더 모델을 잘 만드냐"가 아니라 "누가 더 하드웨어 자원을 효율적으로 쓰느냐"의 싸움이 될 것입니다. 지금 당장 Google Colab을 켜고 TPU 런타임을 선택해 보세요. 새로운 세상이 열릴 것입니다.
자주 묻는 질문 (FAQ)
Q1: Colab 무료 버전에서도 TPU를 쓸 수 있나요?
네, Google Colab은 기본적으로 TPU v2-8을 무료로 제공합니다. 학습 목적으로 사용하기에 충분하며, 프로 버전을 구독하면 더 고성능의 TPU를 사용할 수 있습니다.
Q2: 디버깅이 너무 어렵지 않나요?
XLA 특성상(Lazy Execution) 에러가 발생한 정확한 지점을 찾기 어려울 때가 있습니다. 이럴 때는
JAX_DISABLE_JIT=1 환경 변수를 설정하여 강제로 한 줄씩 실행하게 만들면 디버깅이 쉬워집니다.
Q3: 윈도우(Windows)에서도 JAX를 쓸 수 있나요?
2024년까지는 공식 지원이 미흡했지만, 2025년부터 윈도우에 대한 실험적 지원이 강화되었습니다. 하지만 여전히 리눅스(Linux)나 WSL2 환경에서 사용하는 것이 정신 건강에 좋습니다.