Flax

关于此页面

这是 BentoML 中 FLax 的 API 参考。有关如何在 BentoML 中使用 Flax 的更多信息,请参阅 /frameworks/flax

bentoml.flax.save_model(name: Tag | str, module: nn.Module, state: dict[str, t.Any] | FrozenDict[str, t.Any] | struct.PyTreeNode, *, signatures: ModelSignaturesType | None = None, labels: dict[str, str] | None = None, custom_objects: dict[str, t.Any] | None = None, external_modules: t.List[ModuleType] | None = None, metadata: dict[str, t.Any] | None = None) bentoml.Model

flax.linen.Module 模型实例保存到 BentoML 模型库。

参数:
  • name – 在 BentoML 模型库中为模型指定的名称。这必须是有效的 Tag 名称。

  • module – 要保存的 flax.linen.Module 实例。

  • signatures – 要使用的预测方法的签名。如果未提供,签名默认为 predict。更多详情请参阅 ModelSignature

  • labels – 与模型关联的一组默认管理标签。例如 {"training-set": "data-1"}

  • custom_objects – 要与模型一起保存的自定义对象。例如 {"my-normalizer": normalizer}。自定义对象目前使用 cloudpickle 序列化,但此实现可能会更改。

  • external_modules – 用户定义的额外 Python 模块,与模型或自定义对象一起保存,例如分词器模块、预处理器模块、模型配置模块。

  • metadata – 与模型关联的元数据。例如 {"bias": 4}。元数据旨在用于模型管理 UI 中显示,因此必须是默认的 Python 类型,例如 strint

返回值:

可用于从 BentoML 模型库访问已保存模型的标签。

返回类型:

Tag

示例

import jax

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)

for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(
        state, train_ds, config.batch_size, input_rng
    )
    _, test_loss, test_accuracy = apply_model(
        state, test_ds["image"], test_ds["label"]
    )

    logger.info(
        "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f",
        epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100
    )

# `Save` the model with BentoML
tag = bentoml.flax.save_model("mnist", CNN(), state)
bentoml.flax.load_model(bento_model: str | Tag | bentoml.Model, init: bool = True, device: str | XlaBackend = 'cpu') tuple[nn.Module, dict[str, t.Any]]

从本地 BentoML 模型库加载具有给定标签的 flax.linen.Module 模型实例。

参数:
  • bento_model – 要从模型库获取的模型标签,或用于加载模型的 BentoML ~bentoml.Model 实例。

  • init – 是否初始化给定的 flax.linen.Module 的 state dict。默认情况下,权重和值将被放入 jnp.ndarray。如果 init 设置为 False,则 state_dict 只会被放入给定的加速器设备。

  • device – 放置 state dict 的设备。默认情况下,它将被放置在 cpu 上。这仅在 init 设置为 False 时使用。

返回值:

从模型库加载的 flax.linen.Module 及其 state_dict 的元组。

示例

import bentoml
import jax

net, state_dict = bentoml.flax.load_model("mnist:latest")
predict_fn = jax.jit(lambda s: net.apply({"params": state_dict["params"]}, x))
results = predict_fn(jnp.ones((1, 28, 28, 1)))
bentoml.flax.get(tag_like: str | Tag) bentoml.Model

获取具有给定标签的 BentoML 模型。

参数:

tag_like – 要从模型库检索的模型标签。

返回值:

具有匹配标签的 BentoML Model

返回类型:

模型

示例

import bentoml

model = bentoml.flax.get("mnist:latest")