‘Full Stack Optimization of Transformer Inference: a Survey’ 리뷰 시리즈 1편

Transformer Inference의 성능 병목을 이해하기 위해 ‘Full Stack Optimization of Transformer Inference: a Survey’ 논문을 읽고 있습니다. 이번 글은 본격적인 분석에 앞서 Transformer의 기본 구조와 이를 구성하는 주요 연산들을 간략히 소개합니다.

논문 링크: https://arxiv.org/abs/2302.14017

   


Transformer Architecture

Transformer는 입력을 인코딩하는 Encoder와 출력을 생성하는 Decoder로 구성됩니다. 두 모듈 모두 Multi-Head Attention (MHA)Feed-Forward Network (FFN) 으로 이루어진 Transformer Block을 반복적으로 쌓아 구현됩니다.

Transformer Blocks - Encoder

Transformer는 여러 개의 Transformer Block으로 구성되며, 각 Block은 다음 두 가지 모듈로 구성됩니다:

  • Multi-Head Attention (MHA)
  • Feed-Forward Network (FFN)

각 모듈에는 Layer Normalization (LayerNorm)Residual Connection이 함께 적용되어 안정적인 학습과 추론을 도와줍니다.

입력 시퀀스는 길이 $l$ 의 토큰으로 구성되며, 각 토큰은 차원 $d$ 의 벡터로 표현됩니다.

transformer encoder

 

Multi-Head Attention (MHA)

MHA는 입력을 $h$개의 head로 나누어 병렬로 attention을 수행한 뒤, 결과를 병합합니다. 주요 단계는 다음과 같습니다:

  1. Projection: Query와 key 의 곱을 통해 attention score를 계산하고 softmax를 적용해 정규화
  2. Weighted sum: 정규화 된 attention score와 Value를 곱하여 각 head의 출력을 얻음
  3. Merge and Projection: 모든 head의 출력을 병합한 뒤 $W_{out}$ 과 곱하여 출력 생성
  4. Normalization and Residual Connection: LayerNorm 적용하고 입력을 더하여 최종 출력 생성

MHA에는 총 6개의 matrix multiplication이 포함됩니다. 이 중:

  • 4개는 weight × activation 연산: $W_Q$, $W_K$, $W_V$, $W_{out}$ (초록색)
  • 2개는 activation × activation 연산: $Q \cdot K^\top$, attention × $V$ (빨간색)

이 둘은 연산 방식은 같지만 데이터 재사용 및 성능 특성은 다릅니다. 자세한 내용은 다음 글에서 다룰 예정입니다.

Feed-Forward Network (FFN)

FFN은 다음과 같은 구조를 가집니다:

  1. Linear: 입력을 차원 $d \rightarrow d_{FFN}$으로 확장
  2. GELU Activation
  3. Linear: 차원을 다시 $d_{FFN} \rightarrow d$로 축소
  4. Residual + LayerNorm

FFN에는 2개의 linear 연산이 포함됩니다.

   

Linear 연산: Matmul Dimension 정리

아래 표는 각 모듈의 linear 연산과 연산 차원을 요약한 것입니다. Q·K^Tattention·V는 **시퀀스 길이에 대해 quadratic하고, 나머지는 linear합니다.

ModuleOperationMatmul dim
MHA${W_Q}$ projection$d$ x $d$ x $l$
${W_K}$ projection$d$ x $d$ x $l$
${W_V}$ projection$d$ x $d$ x $l$
query x key$l$ x $d/h$ x $l$
atten. score x value$d/h$ x $l$ x $l$
${W_{out}}$ projection$d$ x $d$ x $l$
FFN${W_1}$ projection${d_{FFN}}$ x $d$ x $l$
${W_2}$ projection${d}$ x ${d_{FFN}}$ x ${l}$
  • $l$ : sequence length
  • $d$ : hidden dimension
  • $h$ : attention heads

   

Non-linear 연산

Transformer에는 다음과 같은 비선형 연산이 포함됩니다: Softmax, LayerNorm, GELU

이들은 연산량은 적지만 메모리 접근 패턴연산 방식 때문에 오버헤드가 발생하기 쉽습니다. 논문의 다음 그림은 각 non-linear 연산이 어느 축으로 작동하고 어떻게 동작하는지 잘 보여줍니다. 초록색 표시가 계산 방향을 나타냅니다.  

non-linear

 

(a) Softmax - 시퀀스 길이를 따라 정규화 Softmax는 주로 어텐션 점수를 정규화할 때 사용되며, 한 행(= 시퀀스 길이) 안에서 exponential 값을 모두 더한 후 각 항목을 나눕니다.

  • 계산 방향: Sequence Length 축
  • 특징:
    • 전체 합을 먼저 계산해야 softmax를 적용할 수 있음
    • 한 번 전체를 스캔해야 하므로 1-pass로는 부족
  • 성능 이슈:
    • 입력 전체를 순회하며 합을 구해야 하므로 메모리 접근 비용이 큼

(b) Layer Normalization – 히든 차원으로 정규화 LayerNorm은 Transformer에서 거의 모든 블록 뒤에 붙는 핵심 연산입니다. 각 토큰의 히든 벡터를 기준으로 평균과 표준편차를 계산하고 정규화합니다.

  • 계산 방향: Hidden Dimension 축 (한 토큰 기준)
  • 특징:
    • 평균(μ), 표준편차(σ) 계산 → 정규화 → scale & shift
    • 따라서 최소 3번 이상 입력을 순회해야 함
  • 성능 이슈:
    • 반복된 접근이 필요해서 메모리 대역폭 병목의 주요 원인 중 하나

(c) Batch Normalization – 채널 단위, 학습된 통계 사용 BatchNorm은 주로 CNN에서 사용되고, 훈련 중에 채널별 통계(μ, σ)를 학습하여 저장합니다. 추론 시에는 학습된 평균과 표준편차를 그대로 사용하므로 속도가 빠릅니다.

  • 계산 방향: Channel 축 (Across batch/height/width)
  • 특징:
    • μ, σ 모두 훈련 중 학습되고 고정됨
    • 추론 시 1-pass로 충분 (실시간 계산 없음)
  • 성능 장점:
    • 추론 시 가장 빠름
    • 연산량 대비 메모리 접근도 적음

Transformer Block - Decoder

Decoder는 Encoder와 유사하지만 다음과 같은 차이가 있습니다:

  • 입력은 한 토큰 ($d × 1$)만 처리
  • 이전 토큰들과 concat하여 MHA 수행 (auto-regressive)
  • 미래 토큰을 참조하지 않도록 Masked Attention 적용
  • Attention은 matrix-vector 연산으로 수행됨 [ $d/h$ x $1$, $l$ x $d/h$ ]

decoder

   

Encoder vs. Decoder: 구조 비교

encoder-decoder

 

항목인코더 (Encoder)디코더 (Decoder)
입력 처리 방식전체 입력 시퀀스를 병렬적으로 처리이전에 생성된 토큰들을 사용하여 순차적으로 처리 (auto-regressive)
기본 구성
  • Multi-Head Self Attention (MHA)
  • Feed-Forward Network (FFN)
  • Masked Multi-Head Self Attention
  • Cross Attention (인코더 출력과 함께)
  • Feed-Forward Network (FFN)
Self-Attention입력 전체에 대해 자유롭게 어텐션 가능미래 정보 차단을 위한 마스킹(masked) 어텐션 사용
Cross-Attention 존재 여부없음 (입력 자체만 사용)있음 (인코더의 출력과 디코더의 입력 간 어텐션)
입력/출력 예시“I am a student” → 인코딩된 벡터“나는 학생입니다” 생성 (한 토큰씩)
용도자연어 이해 (e.g., 문장 분류, 요약)자연어 생성 (e.g., 기계 번역, 텍스트 생성)
병렬성높음 (모든 토큰을 동시에 처리 가능)낮음 (각 토큰이 이전 토큰에 의존함)
캐시 활용일반적으로 없음Key/Value 캐싱으로 이전 연산 결과 재사용 → 성능 향상

Transformer의 대략적인 구조와 각 블락을 이루는 연산 단위를 살펴보았습니다. 다음 글에서는 논문의 2.2장 Model Analysis를 기반으로 BERT 모델과 GPT 모델의 병목 지점을 알아보겠습니다.