‘Full Stack Optimization of Transformer Inference: a Survey’ 리뷰 시리즈 1편
Transformer Inference의 성능 병목을 이해하기 위해 ‘Full Stack Optimization of Transformer Inference: a Survey’ 논문을 읽고 있습니다. 이번 글은 본격적인 분석에 앞서 Transformer의 기본 구조와 이를 구성하는 주요 연산들을 간략히 소개합니다.
논문 링크: https://arxiv.org/abs/2302.14017
Transformer Architecture
Transformer는 입력을 인코딩하는 Encoder와 출력을 생성하는 Decoder로 구성됩니다. 두 모듈 모두 Multi-Head Attention (MHA) 과 Feed-Forward Network (FFN) 으로 이루어진 Transformer Block을 반복적으로 쌓아 구현됩니다.
Transformer Blocks - Encoder
Transformer는 여러 개의 Transformer Block으로 구성되며, 각 Block은 다음 두 가지 모듈로 구성됩니다:
- Multi-Head Attention (MHA)
- Feed-Forward Network (FFN)
각 모듈에는 Layer Normalization (LayerNorm) 과 Residual Connection이 함께 적용되어 안정적인 학습과 추론을 도와줍니다.
입력 시퀀스는 길이 $l$ 의 토큰으로 구성되며, 각 토큰은 차원 $d$ 의 벡터로 표현됩니다.
Multi-Head Attention (MHA)
MHA는 입력을 $h$개의 head로 나누어 병렬로 attention을 수행한 뒤, 결과를 병합합니다. 주요 단계는 다음과 같습니다:
- Projection: Query와 key 의 곱을 통해 attention score를 계산하고 softmax를 적용해 정규화
- Weighted sum: 정규화 된 attention score와 Value를 곱하여 각 head의 출력을 얻음
- Merge and Projection: 모든 head의 출력을 병합한 뒤 $W_{out}$ 과 곱하여 출력 생성
- Normalization and Residual Connection: LayerNorm 적용하고 입력을 더하여 최종 출력 생성
MHA에는 총 6개의 matrix multiplication이 포함됩니다. 이 중:
- 4개는 weight × activation 연산: $W_Q$, $W_K$, $W_V$, $W_{out}$ (초록색)
- 2개는 activation × activation 연산: $Q \cdot K^\top$, attention × $V$ (빨간색)
이 둘은 연산 방식은 같지만 데이터 재사용 및 성능 특성은 다릅니다. 자세한 내용은 다음 글에서 다룰 예정입니다.
Feed-Forward Network (FFN)
FFN은 다음과 같은 구조를 가집니다:
- Linear: 입력을 차원 $d \rightarrow d_{FFN}$으로 확장
- GELU Activation
- Linear: 차원을 다시 $d_{FFN} \rightarrow d$로 축소
- Residual + LayerNorm
FFN에는 2개의 linear 연산이 포함됩니다.
Linear 연산: Matmul Dimension 정리
아래 표는 각 모듈의 linear 연산과 연산 차원을 요약한 것입니다. Q·K^T
와 attention·V
는 **시퀀스 길이에 대해 quadratic하고, 나머지는 linear합니다.
Module | Operation | Matmul dim |
---|---|---|
MHA | ${W_Q}$ projection | $d$ x $d$ x $l$ |
${W_K}$ projection | $d$ x $d$ x $l$ | |
${W_V}$ projection | $d$ x $d$ x $l$ | |
query x key | $l$ x $d/h$ x $l$ | |
atten. score x value | $d/h$ x $l$ x $l$ | |
${W_{out}}$ projection | $d$ x $d$ x $l$ | |
FFN | ${W_1}$ projection | ${d_{FFN}}$ x $d$ x $l$ |
${W_2}$ projection | ${d}$ x ${d_{FFN}}$ x ${l}$ |
- $l$ : sequence length
- $d$ : hidden dimension
- $h$ : attention heads
Non-linear 연산
Transformer에는 다음과 같은 비선형 연산이 포함됩니다: Softmax, LayerNorm, GELU
이들은 연산량은 적지만 메모리 접근 패턴과 연산 방식 때문에 오버헤드가 발생하기 쉽습니다. 논문의 다음 그림은 각 non-linear 연산이 어느 축으로 작동하고 어떻게 동작하는지 잘 보여줍니다. 초록색 표시가 계산 방향을 나타냅니다.
(a) Softmax - 시퀀스 길이를 따라 정규화 Softmax는 주로 어텐션 점수를 정규화할 때 사용되며, 한 행(= 시퀀스 길이) 안에서 exponential 값을 모두 더한 후 각 항목을 나눕니다.
- 계산 방향: Sequence Length 축
- 특징:
- 전체 합을 먼저 계산해야 softmax를 적용할 수 있음
- 한 번 전체를 스캔해야 하므로 1-pass로는 부족
- 성능 이슈:
- 입력 전체를 순회하며 합을 구해야 하므로 메모리 접근 비용이 큼
(b) Layer Normalization – 히든 차원으로 정규화 LayerNorm은 Transformer에서 거의 모든 블록 뒤에 붙는 핵심 연산입니다. 각 토큰의 히든 벡터를 기준으로 평균과 표준편차를 계산하고 정규화합니다.
- 계산 방향: Hidden Dimension 축 (한 토큰 기준)
- 특징:
- 평균(μ), 표준편차(σ) 계산 → 정규화 → scale & shift
- 따라서 최소 3번 이상 입력을 순회해야 함
- 성능 이슈:
- 반복된 접근이 필요해서 메모리 대역폭 병목의 주요 원인 중 하나
(c) Batch Normalization – 채널 단위, 학습된 통계 사용 BatchNorm은 주로 CNN에서 사용되고, 훈련 중에 채널별 통계(μ, σ)를 학습하여 저장합니다. 추론 시에는 학습된 평균과 표준편차를 그대로 사용하므로 속도가 빠릅니다.
- 계산 방향: Channel 축 (Across batch/height/width)
- 특징:
- μ, σ 모두 훈련 중 학습되고 고정됨
- 추론 시 1-pass로 충분 (실시간 계산 없음)
- 성능 장점:
- 추론 시 가장 빠름
- 연산량 대비 메모리 접근도 적음
Transformer Block - Decoder
Decoder는 Encoder와 유사하지만 다음과 같은 차이가 있습니다:
- 입력은 한 토큰 ($d × 1$)만 처리
- 이전 토큰들과 concat하여 MHA 수행 (auto-regressive)
- 미래 토큰을 참조하지 않도록 Masked Attention 적용
- Attention은 matrix-vector 연산으로 수행됨 [ $d/h$ x $1$, $l$ x $d/h$ ]
Encoder vs. Decoder: 구조 비교
항목 | 인코더 (Encoder) | 디코더 (Decoder) |
---|---|---|
입력 처리 방식 | 전체 입력 시퀀스를 병렬적으로 처리 | 이전에 생성된 토큰들을 사용하여 순차적으로 처리 (auto-regressive) |
기본 구성 |
|
|
Self-Attention | 입력 전체에 대해 자유롭게 어텐션 가능 | 미래 정보 차단을 위한 마스킹(masked) 어텐션 사용 |
Cross-Attention 존재 여부 | 없음 (입력 자체만 사용) | 있음 (인코더의 출력과 디코더의 입력 간 어텐션) |
입력/출력 예시 | “I am a student” → 인코딩된 벡터 | “나는 학생입니다” 생성 (한 토큰씩) |
용도 | 자연어 이해 (e.g., 문장 분류, 요약) | 자연어 생성 (e.g., 기계 번역, 텍스트 생성) |
병렬성 | 높음 (모든 토큰을 동시에 처리 가능) | 낮음 (각 토큰이 이전 토큰에 의존함) |
캐시 활용 | 일반적으로 없음 | Key/Value 캐싱으로 이전 연산 결과 재사용 → 성능 향상 |
Transformer의 대략적인 구조와 각 블락을 이루는 연산 단위를 살펴보았습니다. 다음 글에서는 논문의 2.2장 Model Analysis를 기반으로 BERT 모델과 GPT 모델의 병목 지점을 알아보겠습니다.