現代 AI 的計算難題
大型語言模型 (Large language models, LLMs) 已成為當代人工智能的支柱,展現出重塑產業和科學發現的卓越能力。它們在生成類人文本、驅動複雜對話代理,甚至協助複雜研究任務方面的熟練程度,使其成為不可或缺的工具。這些強大模型的核心是 transformer 架構,其設計特點是交替層。輸入數據被分解為 tokens,流經一系列 attention 機制(權衡不同 tokens 的重要性),然後是 feed-forward networks (FFNs)(處理收集到的信息)。這種分層的、序列化的處理方式是 transformers 學習和生成輸出的基礎。
然而,正是這種架構,雖然有效,但隨著模型規模和複雜性的膨脹,也帶來了日益嚴峻的挑戰。序列化的特性意味著每一層通常必須等待前一層完成計算後才能開始。這種逐步處理產生了固有的瓶頸,尤其是在推論 (inference) 階段——即使用訓練好的模型實際生成預測或文本的階段。隨著像驅動高級 AI 助手的模型納入數千億甚至數萬億個參數,推論所需的計算資源和時間急劇增加。這種不斷升級的需求轉化為顯著的延遲 (latency)(響應延遲)、降低的吞吐量 (throughput)(單位時間內處理的請求數量)以及不斷增加的運營成本,阻礙了最強大 LLMs 的廣泛部署和實時應用。因此,提高推論效率已成為 AI 研究社群的首要關注點,激發了對創新策略的探索,這些策略旨在簡化計算,同時不損害這些模型提供的卓越性能。核心挑戰在於減輕序列執行所施加的限制,尤其是在計算跨越多個 GPUs 的分佈式環境中,這會給處理時間增加通信開銷。
探索優化版圖:現有工具及其局限性
在使 LLMs 更精簡、更快速的持續努力中,研究人員開發了一套優化技術工具包。每種技術都提供了一條通往效率的途徑,但往往伴隨著自身的妥協,使得沒有任何單一方法能成為通用解決方案。理解這些權衡對於體會像 FFN Fusion 這樣的新方法的必要性至關重要。
一種突出的技術是 quantization(量化)。這涉及降低用於表示模型權重和激活值的數值精度。模型可能使用 16 位、8 位甚至更低位元的表示,而不是標準的 32 位浮點數。這直接縮小了模型的記憶體佔用,並能顯著加快計算速度,因為對較低精度數字的操作通常更快且需要更少能量。然而,量化並非沒有風險。降低精度可能導致信息損失,從而可能降低模型的準確性。這種風險在非常低的位元寬度下變得更加明顯,需要仔細實施,有時還需要重新訓練以減輕準確性下降。挑戰在於找到能夠最大化效率增益,同時將性能下降控制在可接受範圍內的最佳平衡點。
另一種常見策略是 pruning(剪枝)。該技術基於這樣一個原則:大型神經網絡中的許多參數可能是冗餘的,或者對最終輸出的貢獻很小。剪枝算法識別並移除這些不太重要的連接或神經元,從而產生一個更小、更稀疏的模型。與量化一樣,剪枝減少了記憶體需求和計算負載。然而,精確識別哪些參數是「安全」可移除的非常複雜。過於激進的剪枝可能會無意中移除關鍵組件,導致嚴重的準確性損失。剪枝後通常需要對模型進行微調以恢復性能,這增加了工作流程的複雜性。必須進行仔細的校準,以確保剪枝後的模型仍然有效。
一種在架構上更為獨特的方法是 Mixture-of-Experts (MoE) 模型。MoE 模型並非讓每個輸入都通過整個網絡進行處理,而是由多個「專家」子網絡(通常是 FFNs)組成。對於每個輸入 token,一個門控機制會動態選擇這些專家中的一小部分來執行計算。這種條件計算意味著對於任何給定的輸入,只有模型總參數的一小部分被激活,從而帶來顯著的計算節省,尤其是在訓練和推論非常大的模型時。MoE 模型可以擴展到數萬億個參數,同時保持合理的計算成本。然而,它們的效率高度依賴於工作負載。它們在處理非常大的批次大小 (batch sizes) 時表現出色,因為選擇性激活模式能帶來良好的硬件利用率。在較小或中等批次大小下,MoE 模型可能會因計算資源利用不足而受到影響,因為並行硬件可能無法被稀疏激活的專家持續佔滿。此外,實現和負載平衡 MoE 模型可能比部署標準的「密集」(dense) 架構更為複雜。
雖然 quantization、pruning 和 MoE 模型代表了 LLM 優化方面的寶貴進展,但它們固有的局限性凸顯了對替代或補充策略的需求。研究仍在繼續尋找能夠在各種場景下提供廣泛效率提升的方法,理想情況下,這些方法對準確性或實現複雜性的妥協更少,特別是對於那些因訓練和部署相對簡單而仍然流行的密集模型架構。
FFN Fusion:重新思考 Transformer 中的並行性
在這一優化技術的版圖中,NVIDIA 的研究人員引入了一種引人注目的新方法,稱為 FFN Fusion。這項技術直接挑戰了 transformer 架構中固有的序列瓶頸,不是通過改變參數或選擇性地激活部分,而是通過從根本上重新思考如何並行化計算序列。這項創新的起源是對深度 transformer 模型中 FFN 層行為的一個關鍵觀察。
研究人員使用名為 Puzzle 的診斷工具分析了大型模型的內部運作。當他們實驗性地移除 attention 層時,他們注意到模型通常保留了令人驚訝的長連續 FFN 層序列。更重要的是,分析顯示這些相鄰 FFNs 執行的計算經常表現出極小的相互依賴性。實質上,序列中一個 FFN 的輸出通常不會顯著改變緊隨其後的 FFN 所需的方向路徑或核心信息。這表明,這些傳統上一個接一個執行的 FFNs,可能具有同時、並行執行的潛力,而不會顯著干擾模型的整體功能。
這一洞見構成了 FFN Fusion 的基石。核心思想既優雅簡單又強大:識別具有低計算依賴性的連續 FFN 層序列,並將它們合併成一個單一的、更寬的 FFN 層,該層並行執行等效的計算。結構從 Input -> FFN1 -> FFN2 -> FFN3 -> Output
變為 Input -> Fused_FFN (相當於 FFN1+FFN2+FFN3 並行) -> Output
。這種架構轉換有效地縮短了網絡的序列深度,用一個單一、更寬的計算步驟取代了多個步驟。通過針對這些低依賴性的 FFN 序列,FFN Fusion 旨在減少延遲和計算成本,同時保留模型的表示能力和準確性。從 Llama-3.1-405B-Instruct 開發出 Ultra-253B-Base,正是這項技術潛力的主要展示。
架構煉金術:FFN Fusion 如何運作
FFN Fusion 背後的魔力在於其對前饋網絡底層數學結構的巧妙操縱。它不僅僅是將現有層並排運行;它涉及創建一個新的、統一的層,該層複製原始序列的集體行為,但以並行方式進行。
考慮一個包含 k 個連續 FFN 層的序列。在標準 transformer 中,輸入 x
通過 FFN1
,其輸出成為 FFN2
的輸入,依此類推,直到 FFNk
。每一步都明確依賴於前一步的完成。FFN Fusion 打破了這個依賴鏈。數學上,一個 FFN 通常涉及兩個線性變換,中間有一個非線性激活函數(如 GeLU 或 SwiGLU):FFN(x) = W_out * Activation(W_in * x)
。FFN Fusion 利用了線性變換通常可以組合這一事實。
融合過程通過串接 (concatenating) 各個 FFN 層的權重來實現。具體來說,連續 FFNs 的輸入權重矩陣 (W_in
) 被組合(例如,以塊對角線方式)成融合層的單個、更大的輸入權重矩陣。類似地,輸出權重矩陣 (W_out
) 被串接起來,形成一個單一、更寬的輸出權重矩陣。激活函數在這個更大的結構內按元素應用。這種構造確保了融合後的 FFN 在對應於原始 FFNs 的並行路徑上同時操作原始輸入 x
。來自這些並行路徑的輸出隨後通過串接的輸出權重結構被隱式地聚合起來。
理論基礎證實,只要原始層之間的依賴性確實很低,這種融合結構就能保持與原始 FFNs 序列相同的表示能力。關鍵在於識別哪些序列適合融合。為了系統地做到這一點,NVIDIA 的研究人員採用了一種依賴性分析 (dependency analysis) 技術。他們針對一組代表性的輸入 tokens,測量了連續 FFN 層輸出隱藏狀態之間的餘弦距離 (cosine distance)。較小的餘弦距離表明一個 FFN 的輸出向量指向與序列中下一個 FFN 的輸出向量非常相似的方向。這種相似性表明功能依賴性較低——第二個 FFN 並沒有大幅改變第一個 FFN 建立的信息表示。在各層之間始終表現出低餘弦距離的 FFN 序列被確定為融合的主要候選對象,因為合併它們不太可能擾亂模型的學習表示和整體性能。這種數據驅動的方法允許將 FFN Fusion 有針對性地應用於模型中最有效且干擾最小的部分。
從龐然大物到短跑選手:Ultra-253B-Base 的轉變
FFN Fusion 的實際威力通過其應用於當時公開的最大模型之一 Llama-3.1-405B-Instruct 得以生動展示。這個擁有 4050 億參數的模型,在推論方面代表著一項重大的計算任務。研究人員著手進行架構優化過程,將 FFN Fusion 與策略性剪枝相結合,創建了一個新的、更高效的模型,命名為 Ultra-253B-Base。
轉變過程涉及幾個步驟:
- 分析 (Analysis): 使用他們的依賴性分析工具(測量餘弦距離),研究人員在 Llama-405B 架構中識別出表現出低層間依賴性的連續 FFN 層序列。
- 融合 (Fusion): 這些被識別出的 FFN 序列隨後被融合成單個、更寬的 FFN 層,如前所述(串接權重)。這直接減少了網絡中的序列步驟數量。
- 剪枝 (Pruning): 同時或隨後,被認為不太關鍵的參數(可能通過標準剪枝技術識別,或受融合過程啟發)從模型中移除。
這種組合方法產生了 Ultra-253B-Base,一個擁有 2530 億參數的模型。這比原始的 405B 模型減少了超過 37% 的參數,是一個實質性的縮減。通過融合實現的架構變革是實現如此顯著規模縮減,同時旨在保持性能的關鍵。目標不僅僅是創建一個更小的模型,而是一個從根本上更快、計算成本更低的模型,這得益於 FFN Fusion 解鎖的增強並行性。這個案例研究作為一個關鍵的概念驗證,表明大型模型可以通過結構重組來大幅提高效率。
衡量收益:性能、速度與資源節省
任何優化技術的真正考驗在於其可衡量的影響。對於 Ultra-253B-Base 而言,將 FFN Fusion 和剪枝應用於 Llama-405B 基礎模型所產生的結果令人信服,顯示出在多個維度上的顯著改進,而沒有對能力造成實質性妥協。
推論速度與成本 (Inference Speed and Cost): 最顯著的收益體現在推論效率上。與原始的 405B 參數模型相比,Ultra-253B-Base 實現了:
- 推論延遲提高了 1.71 倍。這意味著模型可以更快地生成響應,這對於實時應用至關重要。
- 在批次大小為 32 時,每個 token 的計算成本降低了 35 倍。每個 token 的計算操作 (FLOPs) 大幅減少,直接轉化為更低的能耗和服務模型所需的硬件要求降低。
模型性能基準測試 (Model Performance Benchmarks): 至關重要的是,這些效率提升並未以犧牲模型的智能或能力為代價。Ultra-253B-Base 在一套標準的 LLM 基準測試中進行了嚴格評估,其得分與原始的、大得多的模型高度競爭,在某些情況下甚至超過了後者:
- MMLU (Massive Multitask Language Understanding): 85.17%
- MMLU-Pro (更具挑戰性的版本): 72.25%
- Arena Hard (針對困難提示的人類偏好評估): 84.92%
- HumanEval (代碼生成能力): 86.58%
- MT-Bench (多輪對話質量): 9.19
這些分數表明,儘管只有 2530 億參數,融合和剪枝後的模型仍保持了非常高的理解、推理、編碼能力和對話質量,與其 405B 參數的祖先相當。
記憶體效率 (Memory Efficiency): 除了計算速度和成本之外,FFN Fusion 還有助於節省記憶體。架構上的改變,可能與融合所帶來的其他優化相結合,使得推論期間所需的鍵值 (key-value, KV) 快取大小減少了 2 倍。KV 快取儲存中間激活值(attention 的鍵和值),會消耗大量 GPU 記憶體,尤其對於長輸入序列。將此需求減半使得在記憶體要求較低的硬件上運行模型成為可能,或者在相同的記憶體限制內處理更長的上下文。
這些可量化的結果突顯了 FFN Fusion 的有效性。它使得創建一個不僅更小,而且在速度、計算操作和記憶體使用方面從根本上更高效的模型成為可能,同時在具有挑戰性的基準測試中保持頂級性能。
保存知識:訓練與微調的關鍵作用
通過像 FFN Fusion 和剪枝這樣的技術,對像 Llama-405B 這樣的大型、預訓練語言模型進行架構修改,不可避免地會打破其學習參數的微妙平衡。雖然數學上的等效性旨在局部保留功能,但網絡的全局行為可能會發生變化。為了確保由此產生的 Ultra-253B-Base 模型不僅變得更高效,而且還能保持其高水平的性能,一個精心策劃的修改後訓練過程至關重要。
這個過程涉及兩個主要階段:
知識蒸餾 (Knowledge Distillation): 第一步是將知識從原始的、更大的模型(或合適的教師模型)轉移回修改後的架構中。這是通過蒸餾實現的,即訓練 Ultra-253B-Base 模型模仿教師模型的輸出或內部表示。此階段使用了大量的數據集,特別是 540 億個 tokens,並使用 8k 上下文窗口 (context window) 進行處理。蒸餾有助於融合和剪枝後的模型重新捕捉在架構更改過程中可能受到輕微干擾的細微差別和能力。
分階段微調 (Staged Fine-Tuning): 蒸餾之後,模型經歷了一系列微調階段,專門設計用於使其適應處理逐漸增長的上下文長度。這對於現代 LLMs 至關重要,因為它們通常需要基於廣泛的輸入來處理和生成文本。微調分階段進行:
- 在 16k 上下文窗口下進行微調。
- 在 32k 上下文窗口下進一步微調。
- 最後在 128k 上下文窗口下進行最終微調階段。
這種分階段的方法使模型能夠逐漸調整其參數,包括新形成的融合 FFN 層和優化的 KV 快取機制,以有效地管理非常長序列上的依賴關係和信息流。每個階段都建立在前一個階段的基礎上,確保在不同上下文大小下的穩定性和穩健性能。
這種細緻的訓練方案,結合了大規模蒸餾和分階段的長上下文微調,對於彌合架構效率和高保真性能之間的差距至關重要。它確保了 FFN Fusion 帶來的速度、成本和記憶體優勢不會損害模型在要求苛刻的基準測試中的準確性和能力。
更廣闊的視野:通用性與未來方向
將 Llama-405B 成功轉化為 Ultra-253B-Base 為 FFN Fusion 的潛力提供了有力證據,但其真正價值在於其更廣泛的適用性以及它為未來 LLM 設計提供的見解。研究表明,這不僅僅是一個僅適用於巨大模型的一次性技巧。
跨規模驗證 (Validation Across Scales): NVIDIA 的研究人員明確地在不同規模的模型上測試了 FFN Fusion 方法。他們成功地將該技術應用於 70B 參數模型,相對於其原始對應模型實現了類似的效率提升。他們還報告了在 49B 規模上的驗證,進一步強化了這樣一種觀點:FFN 的獨立性和融合潛力並非最大模型的專有特性,而可能是 transformer 架構更普遍的屬性,可能在 FFN 序列自然更深的更大規模上變得更加明顯。這表明 FFN Fusion 可能成為 LLM 優化工具箱中的標準工具,適用於各種規模的模型。
FFN 與完整區塊融合 (FFN vs. Full Block Fusion): 該研究還揭示了 FFN 層相對於 transformer 區塊內 attention 層的特定作用。雖然連續的 FFN 層通常表現出低依賴性,使其成為融合的理想選擇,但嘗試並行化整個 transformer 區塊(包括 attention 和 FFN 層)被證明更具挑戰性。分析表明涉及 attention 機制的相互依賴性更強。同時融合整個區塊會導致更顯著的性能下降,這表明 attention 層在跨 tokens 整合信息方面扮演著更關鍵、序列依賴性更強的角色。這一發現有助於劃定有效並行化的界限——FFN 序列是肥沃的土壤,而 attention 機制可能需要不同的優化策略。
對 LLM 架構的啟示 (Implications for LLM Architecture): FFN Fusion 提供的不僅僅是一種事後優化技術;它為設計未來的 LLMs 提供了寶貴的見解。發現 FFN 序列通常可以被視為可並行化的單元,挑戰了通常支撐 transformer 設計的嚴格序列假設。這可能會激發出從一開始就內在更具並行友好性的新架構。未來的模型可能會設計成具有明確意圖用於融合或並行執行的 FFN 結構,可能導致硬件-軟件協同設計,其中 GPU 架構被進一步優化以利用此類並行性。使用餘弦距離量化層間依賴性的系統方法也為理解和重新設計神經網絡結構提供了一個有價值的分析工具。通過證明專注於並行化現有組件的深思熟慮的架構重新設計可以實現顯著的效率提升,FFN Fusion 為開發既強大又在計算上更具可持續性的 LLMs 鋪平了道路。它突顯了一條緩解尖端 AI 不斷增長的資源需求的途徑。