亞麻(Flax)是一個先進的神經網絡庫,建立在 JAX 之上,旨在為研究人員和開發者提供靈活且高效的工具,來構建複雜的機器學習模型。亞麻與 JAX 的無縫整合使得自動微分、即時編譯(Just-In-Time, JIT)和硬體加速器的支持成為可能,這使得它非常適合實驗研究和生產環境。
這篇文章將探討亞麻的核心特性,將其與其他框架進行比較,並提供一個使用亞麻功能性編程方法的實用範例。
學習目標
- 了解亞麻作為一個高效、靈活的神經網絡庫,適合用於研究和生產。
- 學習亞麻的功能性編程方法如何改善機器學習模型的可重現性和調試。
- 探索亞麻的 Linen API,以有效構建和管理複雜的神經網絡架構。
- 發現亞麻與 Optax 的整合,以簡化訓練工作流程中的優化和梯度處理。
- 獲得有關亞麻的參數管理、狀態處理和模型序列化的見解,以便更好地部署和持久化。
這篇文章是數據科學博客馬拉松的一部分。
什麼是亞麻(Flax)?
亞麻是一個高效的神經網絡庫,建立在 JAX 之上,旨在為研究人員和開發者提供構建尖端機器學習模型所需的靈活性和效率。亞麻利用 JAX 的能力,如自動微分和即時編譯,為研究和生產環境提供了一個強大的框架。
比較:亞麻(Flax)與其他框架
亞麻與其他深度學習框架(如 TensorFlow、PyTorch 和 Keras)相比,具有獨特的設計原則:
- 功能性編程範式:亞麻採用純粹的功能風格,將模型視為沒有隱藏狀態的純函數。這種方法增強了可重現性和調試的便利性。
- 與 JAX 的組合性:通過利用 JAX 的轉換(jit、grad、vmap),亞麻允許模型計算的無縫優化和並行化。
- 模組化:亞麻的模組系統促進可重用組件的構建,使得從簡單的構建塊構建複雜架構變得更容易。
- 性能:基於 JAX,亞麻繼承了其高效的性能能力,包括對 GPU 和 TPU 等硬體加速器的支持。
亞麻的主要特性
- Linen API:亞麻的高級 API 用於定義神經網絡層和模型,強調清晰性和易用性。
- 參數管理:使用不可變數據結構高效處理模型參數,促進功能純粹性。
- 與 Optax 的整合:與 Optax 的無縫兼容,這是一個用於 JAX 的梯度處理和優化庫。
- 序列化:提供強大的工具來保存和加載模型參數,便於模型的持久化和部署。
- 可擴展性:能夠創建自定義模組並將其與其他基於 JAX 的庫集成。
設置環境
在使用亞麻構建模型之前,必須設置開發環境並安裝必要的庫。我們將安裝最新版本的 JAX、JAXlib 和亞麻。JAX 是提供高效數值計算的基礎,而亞麻則在其基礎上提供靈活的神經網絡框架。
# 安裝最新的 JAXlib 版本。
!pip install --upgrade -q pip jax jaxlib
# 安裝亞麻:
!pip install --upgrade -q git+https://github.com/google/flax.git
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
解釋:
- JAX 和 JAXlib:JAX 是一個高效的數值計算和自動微分庫,而 JAXlib 提供 JAX 所需的低級實現。
- 亞麻:一個建立在 JAX 之上的神經網絡庫,提供靈活且高效的 API 用於構建模型。
- 亞麻的 Linen API:作為 nn 導入,Linen 是亞麻的高級 API,用於定義神經網絡層和模型。
亞麻基礎:線性回歸範例
線性回歸是一種基礎的機器學習技術,用於建模依賴變數與一個或多個獨立變數之間的關係。在亞麻中,我們可以使用單個密集(完全連接)層來實現線性回歸。
模型實例化
首先,讓我們使用亞麻的 Linen API 實例化一個密集層。
# 我們創建一個密集層實例(以 'features' 參數作為輸入)
model = nn.Dense(features=5
解釋:
- nn.Dense:表示一個密集(完全連接)神經網絡層,具有指定數量的輸出特徵。在這裡,我們創建了一個具有 5 個輸出特徵的密集層。
參數初始化
在亞麻中,模型參數不會存儲在模型內部。相反,您需要使用隨機鍵和虛擬輸入數據來初始化它們。這個過程利用了亞麻的延遲初始化,其中參數形狀根據輸入數據推斷。
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # 虛擬輸入數據
params = model.init(key2, x) # 初始化調用
jax.tree_util.tree_map(lambda x: x.shape, params) # 檢查輸出形狀
解釋:
- 隨機鍵分割:JAX 使用純函數並通過顯式的 PRNG 鍵處理隨機性。我們將初始鍵分割為兩個,以便獨立生成隨機數。
- 虛擬輸入數據:使用形狀為 (10,) 的虛擬輸入 x 來觸發參數初始化過程中的形狀推斷。
- model.init:根據輸入數據的形狀和隨機鍵初始化模型的參數。
- tree_map:將函數應用於參數樹中的每個葉子,以檢查形狀。
注意:JAX 和亞麻像 NumPy 一樣是基於行的系統,這意味著向量被表示為行向量而不是列向量。這可以在這裡的內核形狀中看到。
前向傳播
在初始化參數後,您可以執行前向傳播以計算給定輸入的模型輸出。
model.apply(params, x)
解釋:
- model.apply:使用提供的參數和輸入數據執行模型的前向傳播。
梯度下降訓練
在模型初始化後,我們可以執行梯度下降來訓練我們的線性回歸模型。我們將生成合成數據並定義均方誤差(MSE)損失函數。
# 設置問題維度。
n_samples = 20
x_dim = 10
y_dim = 5
# 生成隨機的真實 W 和 b。
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# 將參數存儲在 FrozenDict pytree 中。
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
# 生成帶有額外噪聲的樣本。
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
解釋:
- 問題維度:定義樣本數(n_samples)、輸入維度(x_dim)和輸出維度(y_dim)。
- 真實參數:隨機初始化生成合成目標數據所需的真實權重 W 和偏差 b。
- FrozenDict:亞麻使用 FrozenDict 來確保參數的不可變性。
- 數據生成:創建帶有噪聲的合成輸入數據 x_samples 和目標數據 y_samples,以模擬現實場景。
定義 MSE 損失函數
接下來,我們將定義均方誤差(MSE)損失函數,並使用 JAX 的 JIT 編譯進行梯度下降以提高效率。
# 定義 MSE 損失函數。
@jax.jit
def mse(params, x_batched, y_batched):
# 定義單對(x, y)的平方損失
def squared_error(x, y):
pred = model.apply(params, x)
return jnp.inner(y - pred, y - pred) / 2.0
# 向量化以計算所有樣本的平均損失。
return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
解釋:
- @jax.jit:JIT 編譯 mse 函數以優化性能。
- squared_error:計算預測值與真實值之間的平方誤差。
- jax.vmap:向量化 squared_error 函數以高效地應用於所有樣本。
- 均方誤差:計算所有樣本的平均損失。
梯度下降參數和更新函數
我們將設置學習率並定義計算梯度和更新模型參數的函數。
learning_rate = 0.3 # 梯度步長。
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(params, learning_rate, grads):
params = jax.tree_util.tree_map(
lambda p, g: p - learning_rate * g, params, grads)
return params
for i in range(101):
# 執行一次梯度更新。
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
params = update_params(params, learning_rate, grads)
if i % 10 == 0:
print(f'Loss step {i}: ', loss_val)
解釋:
- 學習率:決定參數更新過程中的步長。
- loss_grad_fn:使用 jax.value_and_grad 計算損失值及其相對於參數的梯度。
- update_params:通過減去學習率和梯度的乘積來更新模型參數。
訓練循環
最後,我們將執行訓練循環,進行參數更新並監控損失。
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
for i in range(101):
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
if i % 10 == 0:
print('Loss step {}: '.format(i), loss_val)
解釋:
- Optax 優化器:使用指定的學習率初始化 Adam 優化器。
- 優化器狀態:維護優化器所需的狀態(例如,Adam 的動量項)。
- tx.update:根據梯度和優化器狀態計算參數更新。
- optax.apply_updates:將計算出的更新應用於模型參數。
- 訓練循環:迭代訓練步驟,更新參數並監控損失。
使用 Optax 的好處:
- 簡單性:抽象掉手動梯度更新,減少冗餘代碼。
- 靈活性:支持多種優化算法和梯度變換。
- 組合性:允許將簡單的梯度變換組合成更複雜的優化器。
序列化:保存和加載模型
訓練後,您可能希望保存模型的參數以便日後使用或部署。亞麻提供強大的序列化工具來促進這一過程。
from flax import serialization
# 將參數序列化為字節。
bytes_output = serialization.to_bytes(params)
# 將參數序列化為字典。
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
解釋:
- serialization.to_bytes:將參數樹轉換為字節字符串,適合存儲或傳輸。
- serialization.to_state_dict:將參數樹轉換為字典,便於保存為 JSON 或其他人類可讀格式。
反序列化模型
使用 from_bytes 方法和參數模板將模型參數加載回來。
# 使用序列化的字節加載模型。
loaded_params = serialization.from_bytes(params, bytes_output)
定義自定義模型
亞麻的靈活性在於定義超出簡單線性回歸的自定義模型。本節將探討如何創建自定義的多層感知器(MLP)並在模型中管理狀態。
模組基礎
亞麻中的模組是 nn.Module 的子類,表示層或整個模型。以下是如何定義一個具有密集層和激活函數序列的自定義 MLP。
class ExplicitMLP(nn.Module):
features: Sequence[int]
def setup(self):
# 我們自動知道如何處理子模組的列表和字典
self.layers = [nn.Dense(feat) for feat in self.features]
def __call__(self, inputs):
x = inputs
for i, lyr in enumerate(self.layers):
x = lyr(x)
if i != len(self.layers) - 1:
x = nn.relu(x)
return x
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
解釋:
- ExplicitMLP:一個簡單的多層感知器,為每層指定特徵。
- setup():註冊子模組(密集層),亞麻會跟蹤這些模組以進行參數初始化和序列化。
- __call__():定義前向傳播,應用每一層和 ReLU 激活,除了最後一層。
嘗試直接調用模型而不使用 apply 將導致錯誤:
try:
y = model(x) # 返回錯誤
except AttributeError as e:
print(e)
解釋:
- model.apply:亞麻的功能性 API 需要使用 apply 來執行給定參數的模型前向傳播。
使用 @nn.compact 裝飾器
另一種更簡潔的定義子模組的方法是使用 @nn.compact 裝飾器,這可以在 __call__ 方法中使用。
class SimpleMLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features):
x = nn.Dense(feat, name=f'layers_{i}')(x)
if i != len(self.features) - 1:
x = nn.relu(x)
return x
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
解釋:
- @nn.compact:一個裝飾器,允許在 __call__ 方法中定義子模組和參數,使模型定義更簡潔易讀。
- 命名子模組:可選地為子模組提供名稱以增加清晰度;否則,亞麻會自動生成名稱,如 “Dense_0”、“Dense_1”等。
setup 和 @nn.compact 的區別:
- setup 方法:允許在 __call__ 方法外部定義子模組。對於具有多個方法或動態結構的模組非常有用。
- @nn.compact 裝飾器:允許在 __call__ 方法中定義子模組。對於簡單和固定的架構更為簡潔。
模組參數
有時,您可能需要定義亞麻未提供的自定義層。以下是如何使用 @nn.compact 方法從頭開始創建一個簡單的密集層。
class SimpleDense(nn.Module):
features: int
kernel_init: Callable = nn.initializers.lecun_normal()
bias_init: Callable = nn.initializers.zeros_init()
@nn.compact
def __call__(self, inputs):
kernel = self.param('kernel',
self.kernel_init, # 初始化函數
(inputs.shape[-1], self.features)) # 形狀信息。
y = jnp.dot(inputs, kernel)
bias = self.param('bias', self.bias_init, (self.features,))
y = y + bias
return y
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4, 4))
model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameters:\n', params)
print('output:\n', y)
解釋:
- 自定義參數:使用 self.param 註冊自定義參數(內核和偏差)。
- 初始化函數:指定每個參數的初始化方式。
- 手動計算:使用 jnp.dot 手動執行密集計算。
關鍵點:
- self.param:註冊一個帶有名稱、初始化函數和形狀的參數。
- 手動參數管理:提供對參數定義和初始化的細粒度控制。
變數和變數集合
除了參數,神經網絡通常還維護狀態變數,例如批量正則化中的運行統計。亞麻允許您使用變數方法來管理這些變數。
範例:帶有運行均值的偏差加法器
class BiasAdderWithRunningMean(nn.Module):
decay: float = 0.99
@nn.compact
def __call__(self, x):
# 檢查 'mean' 變數是否已初始化。
is_initialized = self.has_variable('batch_stats', 'mean')
# 初始化均值的運行平均值。
ra_mean = self.variable('batch_stats', 'mean',
lambda s: jnp.zeros(s),
x.shape[1:])
# 初始化偏差參數。
bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
if is_initialized:
ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
return x - ra_mean.value + bias
# 初始化並應用模型。
key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10, 5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
解釋:
- self.variable:在 ‘batch_stats’ 集合下註冊一個可變變數(均值)。
- 狀態初始化:用零初始化運行均值。
- 狀態更新:如果已初始化,則在前向傳播期間更新運行均值。
- 可變狀態:使用 apply 中的 mutable 參數指定哪些集合在前向傳播期間是可變的。
管理優化器和模型狀態
處理參數和狀態變數(如運行均值)可能很複雜。以下是一個示例,展示如何使用 Optax 將參數更新與狀態變數更新集成。
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10,5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
old_state, params = flax.core.pop(variables, 'params')
variables = flax.core.freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # 僅顯示可變部分
from functools import partial
@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):
def loss(params):
y, updated_state = apply_fn({'params': params, **state},
x, mutable=list(state.keys()))
l = ((x - y) ** 2).sum()
return l, updated_state
(l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return opt_state, params, state
x = jnp.ones((10,5))
variables = model.init(random.key(0), x)
state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
for _ in range(3):
opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
print('Updated state: ', state)
解釋:
- update_step 函數:一個 JIT 編譯的函數,同時更新參數和狀態變數。
- 損失函數:計算損失並同時更新狀態變數。
- 梯度計算:使用 jax.value_and_grad 計算相對於參數的梯度。
- Optax 更新:將優化器更新應用於參數。
- 訓練循環:迭代訓練步驟,同時更新參數和狀態變數。
注意:函數簽名可能冗長,並且可能無法直接與 jax.jit() 一起使用,因為某些函數參數不是“有效的 JAX 類型”。亞麻提供了一個方便的包裝器,稱為 TrainState,以簡化此過程。請參閱 flax.training.train_state.TrainState 獲取更多信息。
使用 jax2tf 將模型導出到 TensorFlow 的 SavedModel
JAX 發布了一個實驗性轉換器,稱為 jax2tf,允許將訓練好的亞麻模型轉換為 TensorFlow SavedModel 格式(以便用於 TF Hub、TF.lite、TF.js 或其他下游應用)。該庫包含更多文檔並提供各種亞麻的範例。
結論
亞麻是一個多功能且強大的神經網絡庫,利用 JAX 的高效能力。從設置簡單的線性回歸模型到定義複雜的自定義架構和管理狀態,亞麻為研究和生產環境提供了一個靈活的框架。
在本指南中,我們涵蓋了:
- 環境設置:安裝 JAX、JAXlib 和亞麻。
- 線性回歸:實現和訓練簡單的線性模型。
- 使用 Optax 進行優化:通過先進的優化器簡化訓練過程。
- 序列化:高效保存和加載模型參數。
- 自定義模型:構建具有狀態管理的自定義神經網絡架構。
通過掌握這些基礎知識,您將能夠充分利用亞麻在機器學習項目中的潛力。無論您是在進行學術研究、開發生產就緒的模型,還是探索創新的架構,亞麻都提供了支持您努力的工具和靈活性。
關鍵要點
- 亞麻是一個靈活的高效神經網絡庫,建立在 JAX 之上,為深度學習模型提供模組化和組合性。
- 它遵循功能性編程範式,增強模型的可重現性、調試和可維護性。
- 亞麻與 JAX 無縫整合,利用其優化和並行化能力進行高速計算。
- Linen API 和 `@nn.compact` 裝飾器簡化了神經網絡層和參數的定義和管理。
- 亞麻提供狀態管理、模型序列化和使用可組合優化器(如 Optax)進行高效訓練的工具。
本文中顯示的媒體不屬於 Analytics Vidhya,並由作者自行決定使用。
常見問題
答:亞麻是一個建立在 JAX 之上的先進神經網絡庫,旨在提供高靈活性和性能。研究人員和開發者使用它來高效構建複雜的機器學習模型,利用 JAX 的自動微分和 JIT 編譯進行優化計算。
答:亞麻因其採用功能性編程範式而脫穎而出,將模型視為沒有隱藏狀態的純函數。這促進了調試和可重現性的便利性。它還與 JAX 深度整合,能夠無縫使用 jit、grad 和 vmap 等轉換以增強優化。
答:Linen API 是亞麻的高級、用戶友好的 API,用於定義神經網絡層和模型。它強調清晰性和模組化,使得構建、理解和擴展複雜架構變得更容易。
答:Optax 庫為 JAX 提供先進的梯度處理和優化工具。與亞麻一起使用時,它通過可組合的優化器簡化訓練過程,減少手動編碼並增強靈活性,支持多種優化算法。
答:亞麻使用不可變數據結構(如 FrozenDict)進行參數管理,確保功能純粹性。模型狀態(如批量正則化的運行統計)可以使用集合進行管理,並在前向傳播期間使用可變參數進行更新。
新聞來源
本文由 AI 台灣 使用 AI 編撰,內容僅供參考,請自行進行事實查核。加入 AI TAIWAN Google News,隨時掌握最新 AI 資訊!