๐Ÿง  BitNet ๋…ผ๋ฌธ ์š”์•ฝ โ€“ 1bit๋กœ ํ•™์Šต๋˜๋Š” LLM์˜ ๊ฐ€๋Šฅ์„ฑ

๐Ÿ“Œ ๋…ผ๋ฌธ ๊ฐœ์š”

  • ์ œ๋ชฉ: BitNet: Training Language Models in 1 Bit
  • ์ €์ž: Microsoft Azure AI ์—ฐ๊ตฌํŒ€
  • ๋ฐœํ‘œ ์‹œ์ : 2023๋…„ 10์›”
  • ํ•ต์‹ฌ ๊ธฐ์—ฌ:
    • ํ›ˆ๋ จ ๊ฐ€๋Šฅํ•œ 1bit weight matrix๋ฅผ ๋„์ž…
    • GEMM ๊ณฑ์…ˆ์„ ์ œ๊ฑฐํ•˜์—ฌ ํ•™์Šต/์ถ”๋ก  ๋น„์šฉ ์ตœ์†Œํ™”
    • ๊ฒฝ์Ÿ๋ ฅ ์žˆ๋Š” ์„ฑ๋Šฅ ์œ ์ง€ํ•˜๋ฉฐ๋„ 8~16๋ฐฐ ๋” ํšจ์œจ์ ์ธ ๋ชจ๋ธ ๊ตฌํ˜„

๐Ÿ“ 1. ์—ฐ๊ตฌ ๋ชฉ์ 

LLM์˜ ์—ฐ์‚ฐ/์ž์› ์†Œ๋น„๋ฅผ ๊ทน์ ์œผ๋กœ ์ค„์ด๊ธฐ ์œ„ํ•ด,
ํ›ˆ๋ จ๋ถ€ํ„ฐ 1bit ์ •๋ฐ€๋„๋กœ ์ˆ˜ํ–‰ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ์„ ๋งŒ๋“ค๊ณ ์ž ํ–ˆ์Šต๋‹ˆ๋‹ค.

๊ธฐ์กด์˜ ํ•œ๊ณ„

  • ๊ธฐ์กด ์–‘์žํ™” ๊ธฐ์ˆ ์€ ๋Œ€๋ถ€๋ถ„ ์ถ”๋ก  ์ „์šฉ (post-training quantization)
  • ํ•™์Šต์€ ์—ฌ์ „ํžˆ FP16 ์ด์ƒ์˜ ์ •๋ฐ€๋„๊ฐ€ ํ•„์š”
  • ์–‘์žํ™”๋œ weight์—์„œ ํ›ˆ๋ จ ๋ถˆ์•ˆ์ •, ์„ฑ๋Šฅ ํ•˜๋ฝ, ๊ทธ๋ž˜๋””์–ธํŠธ ์†Œ์‹ค ๋ฌธ์ œ

๐Ÿ”ง 2. BitNet ๋ชจ๋ธ ๊ตฌ์กฐ

ํ•ต์‹ฌ ์„ค๊ณ„

  • Weight Matrix = sign(W) ร— ฮฑ
    • sign(W): -1 ๋˜๋Š” +1๋กœ ๊ตฌ์„ฑ๋œ Binary Matrix
    • ฮฑ: ์Šค์ผ€์ผ ํŒŒ๋ผ๋ฏธํ„ฐ (learnable)
  • GEMM ์—ฐ์‚ฐ์„ ์ •์ˆ˜ ๊ธฐ๋ฐ˜ ๋ถ€ํ˜ธ ๊ณฑ์…ˆ์œผ๋กœ ๋‹จ์ˆœํ™”
  • ํ™œ์„ฑ ํ•จ์ˆ˜: ๋น„์„ ํ˜•์„ฑ์„ ์ค„์ด๊ธฐ ์œ„ํ•ด GELU ๋Œ€์‹  Identity ๋˜๋Š” ReLU ์‚ฌ์šฉ

๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ์š”์•ฝ

  • GPT ๊ณ„์—ด Transformer ๊ตฌ์กฐ ๊ธฐ๋ฐ˜
  • LayerNorm โ†’ Linear (Binary) โ†’ Dropout
  • ์ •๋ฐ€๋„ ์ œํ•œ ์™ธ์—๋Š” ๊ตฌ์กฐ์ ์œผ๋กœ ๊ธฐ์กด GPT์™€ ์œ ์‚ฌ

๐Ÿ”ฌ 3. ์‹คํ—˜ ๊ตฌ์„ฑ

์‹คํ—˜ ๋ชจ๋ธ

  • BitNet-1b: 39M, 110M, 390M, 1.3B ํŒŒ๋ผ๋ฏธํ„ฐ ๋ฒ„์ „
  • ๋น„๊ต๊ตฐ: GPT2-small, GPT2-medium, LLaMA 1/2 7B

์‹คํ—˜ ๋ฐ์ดํ„ฐ

  • Pretraining: The Pile + C4 + Wikipedia ๋“ฑ ํ˜ผํ•ฉ
  • ํ‰๊ฐ€ ๋ฒค์น˜๋งˆํฌ: MMLU, HellaSwag, PIQA, Winogrande ๋“ฑ

๐Ÿ“Š 4. ์„ฑ๋Šฅ ๊ฒฐ๊ณผ

MMLU (์–ธ์–ด/์ถ”๋ก  ์ข…ํ•ฉ ํ‰๊ฐ€)

๋ชจ๋ธParam ์ˆ˜์ •๋ฐ€๋„MMLU ์ •ํ™•๋„
BitNet-1b-1.3B1.3B1bit64.7%
GPT-21.5BFP3258.9%
LLaMA-7B7BFP1667.0%

FLOPs ๊ณ„์‚ฐ ๋น„์šฉ ๋น„๊ต

  • BitNet์€ LLaMA-7B ๋Œ€๋น„ 16๋ฐฐ ์ ์€ FLOPs๋กœ ๊ฑฐ์˜ ์œ ์‚ฌํ•œ ์„ฑ๋Šฅ ๋‹ฌ์„ฑ
  • ์ถ”๋ก  ์‹œ GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์€ 4GB ์ดํ•˜๋กœ ์ž‘๋™ ๊ฐ€๋Šฅ

๐Ÿ“‰ 5. ์„ฑ๋Šฅ ์œ ์ง€ ๊ธฐ์ˆ 

Gradient Scaling

  • ฮฑ (์Šค์ผ€์ผ ํŒŒ๋ผ๋ฏธํ„ฐ)๋Š” ํ•™์Šต ๊ฐ€๋Šฅํ•˜์ง€๋งŒ, ๊ทธ๋ž˜๋””์–ธํŠธ๋Š” ํŠน์ˆ˜ํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ ์•ˆ์ •ํ™”
  • PACT, DoReFa์™€ ๊ฐ™์€ ๊ธฐ์กด ์–‘์žํ™” ํ•™์Šต๋ฒ•์„ ์ผ๋ถ€ ์ฐธ์กฐ

Dropout, Residual, LayerNorm ์กฐ์ •

  • Gradient ํ๋ฆ„์„ ์œ ์ง€ํ•˜๊ธฐ ์œ„ํ•ด ์ผ๋ถ€ ๊ตฌ์กฐ ์š”์†Œ๋ฅผ ๊ฒฝ๋Ÿ‰ํ™” + ์žฌ๋ฐฐ์น˜

โœ… 6. ์ฃผ์š” ๊ธฐ์—ฌ ์š”์•ฝ

๊ธฐ์—ฌ์„ค๋ช…
์ตœ์ดˆ์˜ 1bit ํ•™์Šตํ˜• LLM ์ œ์•ˆํ›ˆ๋ จ๊ณผ ์ถ”๋ก  ๋ชจ๋‘ 1bit๋กœ ์ˆ˜ํ–‰
์ถ”๋ก  FLOPs / ๋ฉ”๋ชจ๋ฆฌ ๊ทน์†Œํ™”8~16๋ฐฐ ์ ˆ๊ฐ
GPT2๋ฅผ ๋Šฅ๊ฐ€ํ•˜๋Š” ์ •ํ™•๋„1.3B ๋ชจ๋ธ๋กœ GPT2 FP32๋ฅผ ์ดˆ๊ณผ
ํ™•์žฅ์„ฑ ์žˆ๋Š” ์‹คํ—˜ ์„ค๊ณ„๋‹ค์–‘ํ•œ ์‚ฌ์ด์ฆˆ ๋ชจ๋ธ์— ์ ์šฉ ๊ฐ€๋Šฅ

๐Ÿ“š ๋…ผ๋ฌธ ๊ฒฐ๋ก 

BitNet์€ ํ•™์Šต๊ณผ ์ถ”๋ก  ๋ชจ๋‘๋ฅผ 1bit๋กœ ์ˆ˜ํ–‰ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ๋กœ,
โ€œ์ดˆ๊ฒฝ๋Ÿ‰ LLM ๊ตฌํ˜„์˜ ์‹คํ˜„ ๊ฐ€๋Šฅ์„ฑโ€์„ ์ œ์‹œํ–ˆ์Šต๋‹ˆ๋‹ค.

  • ์ดˆ์ €๋น„์šฉ ํ›ˆ๋ จ ์ธํ”„๋ผ ๊ฐ€๋Šฅ์„ฑ ์ œ์‹œ
  • ๋ชจ๋ฐ”์ผ/์ž„๋ฒ ๋””๋“œ์šฉ LLM ํ•™์Šต ์‹œ๋Œ€๋ฅผ ์—ด ์ˆ˜ ์žˆ์Œ
  • ์•„์ง ๋Œ€๊ทœ๋ชจ reasoning์ด๋‚˜ instruction tuning ์„ฑ๋Šฅ์€ ์ œํ•œ์ ์ด์ง€๋งŒ, ๊ธฐ์ˆ  ์ž ์žฌ๋ ฅ์€ ๋งค์šฐ ๋†’์Œ

๐Ÿ“Ž ์ถ”๊ฐ€ ์ฐธ๊ณ ์ž๋ฃŒ

  • ๋…ผ๋ฌธ ๋งํฌ: https://arxiv.org/abs/2310.11453
  • PyTorch ์ฝ”๋“œ (๋น„๊ณต์‹): GitHub ์ปค๋ฎค๋‹ˆํ‹ฐ ๊ธฐ๋ฐ˜ ๊ตฌํ˜„ ์žˆ์Œ
  • ๊ด€๋ จ ๊ธฐ์ˆ : DoReFa-Net, QLoRA, Binarized Neural Networks