Flax¶
关于此页面
这是 BentoML 中 FLax 的 API 参考。有关如何在 BentoML 中使用 Flax 的更多信息,请参阅 /frameworks/flax。
- bentoml.flax.save_model(name: Tag | str, module: nn.Module, state: dict[str, t.Any] | FrozenDict[str, t.Any] | struct.PyTreeNode, *, signatures: ModelSignaturesType | None = None, labels: dict[str, str] | None = None, custom_objects: dict[str, t.Any] | None = None, external_modules: t.List[ModuleType] | None = None, metadata: dict[str, t.Any] | None = None) bentoml.Model ¶
将
flax.linen.Module
模型实例保存到 BentoML 模型库。- 参数:
name – 在 BentoML 模型库中为模型指定的名称。这必须是有效的
Tag
名称。module – 要保存的
flax.linen.Module
实例。signatures – 要使用的预测方法的签名。如果未提供,签名默认为
predict
。更多详情请参阅ModelSignature
。labels – 与模型关联的一组默认管理标签。例如
{"training-set": "data-1"}
。custom_objects – 要与模型一起保存的自定义对象。例如
{"my-normalizer": normalizer}
。自定义对象目前使用 cloudpickle 序列化,但此实现可能会更改。external_modules – 用户定义的额外 Python 模块,与模型或自定义对象一起保存,例如分词器模块、预处理器模块、模型配置模块。
metadata – 与模型关联的元数据。例如
{"bias": 4}
。元数据旨在用于模型管理 UI 中显示,因此必须是默认的 Python 类型,例如 str 或 int。
- 返回值:
可用于从 BentoML 模型库访问已保存模型的标签。
- 返回类型:
Tag
示例
import jax rng, init_rng = jax.random.split(rng) state = create_train_state(init_rng, config) for epoch in range(1, config.num_epochs + 1): rng, input_rng = jax.random.split(rng) state, train_loss, train_accuracy = train_epoch( state, train_ds, config.batch_size, input_rng ) _, test_loss, test_accuracy = apply_model( state, test_ds["image"], test_ds["label"] ) logger.info( "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f", epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100 ) # `Save` the model with BentoML tag = bentoml.flax.save_model("mnist", CNN(), state)
- bentoml.flax.load_model(bento_model: str | Tag | bentoml.Model, init: bool = True, device: str | XlaBackend = 'cpu') tuple[nn.Module, dict[str, t.Any]] ¶
从本地 BentoML 模型库加载具有给定标签的
flax.linen.Module
模型实例。- 参数:
bento_model – 要从模型库获取的模型标签,或用于加载模型的 BentoML ~bentoml.Model 实例。
init – 是否初始化给定的
flax.linen.Module
的 state dict。默认情况下,权重和值将被放入jnp.ndarray
。如果init
设置为False
,则 state_dict 只会被放入给定的加速器设备。device – 放置 state dict 的设备。默认情况下,它将被放置在
cpu
上。这仅在init
设置为False
时使用。
- 返回值:
从模型库加载的
flax.linen.Module
及其state_dict
的元组。
示例
import bentoml import jax net, state_dict = bentoml.flax.load_model("mnist:latest") predict_fn = jax.jit(lambda s: net.apply({"params": state_dict["params"]}, x)) results = predict_fn(jnp.ones((1, 28, 28, 1)))