feat: allow user to register model alias explicitly, tests

# What does this PR do?

Context: https://github.com/llamastack/llama-stack/discussions/3483

This PR enables the registering `provider_model_id` as the model identifier without breaking backward compatibility.


## Test Plan
todo
# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-09-18 15:47:20 -07:00
parent ac1414b571
commit 83a229554b
20 changed files with 236 additions and 92 deletions

View file

@ -69,11 +69,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def register_model(
self,
model_id: str,
model_id: str | None = None,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
use_provider_model_id_as_id: bool = False,
) -> Model:
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
@ -85,6 +86,17 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
)
if model_id is None and provider_model_id is None:
raise ValueError("provider_model_id must be provided")
if model_id == provider_model_id:
logger.warning(
f"`model_id` is now optional. Please remove `{model_id=}` and use `{provider_model_id=}` instead."
)
if use_provider_model_id_as_id and model_id:
raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {model_id=}")
provider_model_id = provider_model_id or model_id
metadata = metadata or {}
model_type = model_type or ModelType.llm
@ -94,8 +106,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# an identifier different than provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate.
if model_id != provider_model_id:
if use_provider_model_id_as_id:
identifier = provider_model_id
elif model_id and model_id != provider_model_id:
identifier = model_id
else:
identifier = f"{provider_id}/{provider_model_id}"