๐ ๋ ผ๋ฌธ ๊ฐ์
- ์ ๋ชฉ: BitNet: Training Language Models in 1 Bit
- ์ ์: Microsoft Azure AI ์ฐ๊ตฌํ
- ๋ฐํ ์์ : 2023๋ 10์
- ํต์ฌ ๊ธฐ์ฌ:
- ํ๋ จ ๊ฐ๋ฅํ 1bit weight matrix๋ฅผ ๋์
- GEMM ๊ณฑ์ ์ ์ ๊ฑฐํ์ฌ ํ์ต/์ถ๋ก ๋น์ฉ ์ต์ํ
- ๊ฒฝ์๋ ฅ ์๋ ์ฑ๋ฅ ์ ์งํ๋ฉฐ๋ 8~16๋ฐฐ ๋ ํจ์จ์ ์ธ ๋ชจ๋ธ ๊ตฌํ
๐ 1. ์ฐ๊ตฌ ๋ชฉ์
LLM์ ์ฐ์ฐ/์์ ์๋น๋ฅผ ๊ทน์ ์ผ๋ก ์ค์ด๊ธฐ ์ํด,
ํ๋ จ๋ถํฐ 1bit ์ ๋ฐ๋๋ก ์ํ ๊ฐ๋ฅํ ๋ชจ๋ธ์ ๋ง๋ค๊ณ ์ ํ์ต๋๋ค.
๊ธฐ์กด์ ํ๊ณ
- ๊ธฐ์กด ์์ํ ๊ธฐ์ ์ ๋๋ถ๋ถ ์ถ๋ก ์ ์ฉ (post-training quantization)
- ํ์ต์ ์ฌ์ ํ FP16 ์ด์์ ์ ๋ฐ๋๊ฐ ํ์
- ์์ํ๋ weight์์ ํ๋ จ ๋ถ์์ , ์ฑ๋ฅ ํ๋ฝ, ๊ทธ๋๋์ธํธ ์์ค ๋ฌธ์
๐ง 2. BitNet ๋ชจ๋ธ ๊ตฌ์กฐ
ํต์ฌ ์ค๊ณ
- Weight Matrix = sign(W) ร ฮฑ
sign(W)
: -1 ๋๋ +1๋ก ๊ตฌ์ฑ๋ Binary Matrixฮฑ
: ์ค์ผ์ผ ํ๋ผ๋ฏธํฐ (learnable)
- GEMM ์ฐ์ฐ์ ์ ์ ๊ธฐ๋ฐ ๋ถํธ ๊ณฑ์ ์ผ๋ก ๋จ์ํ
- ํ์ฑ ํจ์: ๋น์ ํ์ฑ์ ์ค์ด๊ธฐ ์ํด GELU ๋์ Identity ๋๋ ReLU ์ฌ์ฉ
๋ชจ๋ธ ์ํคํ ์ฒ ์์ฝ
- GPT ๊ณ์ด Transformer ๊ตฌ์กฐ ๊ธฐ๋ฐ
- LayerNorm โ Linear (Binary) โ Dropout
- ์ ๋ฐ๋ ์ ํ ์ธ์๋ ๊ตฌ์กฐ์ ์ผ๋ก ๊ธฐ์กด GPT์ ์ ์ฌ
๐ฌ 3. ์คํ ๊ตฌ์ฑ
์คํ ๋ชจ๋ธ
- BitNet-1b: 39M, 110M, 390M, 1.3B ํ๋ผ๋ฏธํฐ ๋ฒ์
- ๋น๊ต๊ตฐ: GPT2-small, GPT2-medium, LLaMA 1/2 7B
์คํ ๋ฐ์ดํฐ
- Pretraining: The Pile + C4 + Wikipedia ๋ฑ ํผํฉ
- ํ๊ฐ ๋ฒค์น๋งํฌ: MMLU, HellaSwag, PIQA, Winogrande ๋ฑ
๐ 4. ์ฑ๋ฅ ๊ฒฐ๊ณผ
MMLU (์ธ์ด/์ถ๋ก ์ข ํฉ ํ๊ฐ)
๋ชจ๋ธ | Param ์ | ์ ๋ฐ๋ | MMLU ์ ํ๋ |
---|---|---|---|
BitNet-1b-1.3B | 1.3B | 1bit | 64.7% |
GPT-2 | 1.5B | FP32 | 58.9% |
LLaMA-7B | 7B | FP16 | 67.0% |
FLOPs ๊ณ์ฐ ๋น์ฉ ๋น๊ต
- BitNet์ LLaMA-7B ๋๋น 16๋ฐฐ ์ ์ FLOPs๋ก ๊ฑฐ์ ์ ์ฌํ ์ฑ๋ฅ ๋ฌ์ฑ
- ์ถ๋ก ์ GPU ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ 4GB ์ดํ๋ก ์๋ ๊ฐ๋ฅ
๐ 5. ์ฑ๋ฅ ์ ์ง ๊ธฐ์
Gradient Scaling
ฮฑ
(์ค์ผ์ผ ํ๋ผ๋ฏธํฐ)๋ ํ์ต ๊ฐ๋ฅํ์ง๋ง, ๊ทธ๋๋์ธํธ๋ ํน์ํ ๋ฐฉ๋ฒ์ผ๋ก ์์ ํ- PACT, DoReFa์ ๊ฐ์ ๊ธฐ์กด ์์ํ ํ์ต๋ฒ์ ์ผ๋ถ ์ฐธ์กฐ
Dropout, Residual, LayerNorm ์กฐ์
- Gradient ํ๋ฆ์ ์ ์งํ๊ธฐ ์ํด ์ผ๋ถ ๊ตฌ์กฐ ์์๋ฅผ ๊ฒฝ๋ํ + ์ฌ๋ฐฐ์น
โ 6. ์ฃผ์ ๊ธฐ์ฌ ์์ฝ
๊ธฐ์ฌ | ์ค๋ช |
---|---|
์ต์ด์ 1bit ํ์ตํ LLM ์ ์ | ํ๋ จ๊ณผ ์ถ๋ก ๋ชจ๋ 1bit๋ก ์ํ |
์ถ๋ก FLOPs / ๋ฉ๋ชจ๋ฆฌ ๊ทน์ํ | 8~16๋ฐฐ ์ ๊ฐ |
GPT2๋ฅผ ๋ฅ๊ฐํ๋ ์ ํ๋ | 1.3B ๋ชจ๋ธ๋ก GPT2 FP32๋ฅผ ์ด๊ณผ |
ํ์ฅ์ฑ ์๋ ์คํ ์ค๊ณ | ๋ค์ํ ์ฌ์ด์ฆ ๋ชจ๋ธ์ ์ ์ฉ ๊ฐ๋ฅ |
๐ ๋ ผ๋ฌธ ๊ฒฐ๋ก
BitNet์ ํ์ต๊ณผ ์ถ๋ก ๋ชจ๋๋ฅผ 1bit๋ก ์ํ ๊ฐ๋ฅํ ๋ชจ๋ธ๋ก,
โ์ด๊ฒฝ๋ LLM ๊ตฌํ์ ์คํ ๊ฐ๋ฅ์ฑโ์ ์ ์ํ์ต๋๋ค.
- ์ด์ ๋น์ฉ ํ๋ จ ์ธํ๋ผ ๊ฐ๋ฅ์ฑ ์ ์
- ๋ชจ๋ฐ์ผ/์๋ฒ ๋๋์ฉ LLM ํ์ต ์๋๋ฅผ ์ด ์ ์์
- ์์ง ๋๊ท๋ชจ reasoning์ด๋ instruction tuning ์ฑ๋ฅ์ ์ ํ์ ์ด์ง๋ง, ๊ธฐ์ ์ ์ฌ๋ ฅ์ ๋งค์ฐ ๋์
๐ ์ถ๊ฐ ์ฐธ๊ณ ์๋ฃ
- ๋ ผ๋ฌธ ๋งํฌ: https://arxiv.org/abs/2310.11453
- PyTorch ์ฝ๋ (๋น๊ณต์): GitHub ์ปค๋ฎค๋ํฐ ๊ธฐ๋ฐ ๊ตฌํ ์์
- ๊ด๋ จ ๊ธฐ์ : DoReFa-Net, QLoRA, Binarized Neural Networks