feat: allow user to register model alias explicitly, tests

# What does this PR do?

Context: https://github.com/llamastack/llama-stack/discussions/3483

This PR enables registering `provider_model_id` as the model identifier without breaking backward compatibility.


## Test Plan
todo
# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-09-18 15:47:20 -07:00
parent ac1414b571
commit 83a229554b
20 changed files with 236 additions and 92 deletions

View file

@ -116,7 +116,8 @@ models:
model_id: all-MiniLM-L6-v2 model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers provider_id: sentence-transformers
model_type: embedding model_type: embedding
- model_id: ${env.INFERENCE_MODEL} - use_provider_model_id_as_id: true
provider_model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference provider_id: vllm-inference
model_type: llm model_type: llm
shields: shields:

View file

@ -10,9 +10,12 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import json_schema_type, webmethod
logger = get_logger(name=__name__, category="core")
class CommonModelFields(BaseModel): class CommonModelFields(BaseModel):
metadata: dict[str, Any] = Field( metadata: dict[str, Any] = Field(
@ -68,11 +71,36 @@ class Model(CommonModelFields, Resource):
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str """A model input for registering a model.
provider_id: str | None = None
:param provider_model_id: The identifier of the model in the provider.
:param provider_id: The identifier of the provider.
:param model_type: The type of model to register.
:param model_id: The identifier of the model to register. If model_id == provider_model_id, provider_id/provider_model_id will be used as the identifier. Otherwise,
model_id will be used as the identifier.
The behavior of this field will soon change to "always use model_id as the identifier".
:param use_provider_model_id_as_id: Set to true to use provider_model_id as the identifier. Use model_id if you want to use a different identifier.
"""
provider_model_id: str | None = None provider_model_id: str | None = None
provider_id: str | None = None
model_type: ModelType | None = ModelType.llm model_type: ModelType | None = ModelType.llm
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
# TODO: update behavior of this field to always be the identifier
model_id: str | None = None
use_provider_model_id_as_id: bool = False
def model_post_init(self, __context: Any) -> None:
    """Check the identifier-related fields after the model has been built.

    :param __context: post-init context argument; not used by this check.
    :raises ValueError: if both ``model_id`` and ``provider_model_id`` are
        missing, or if ``use_provider_model_id_as_id`` is set together with
        a non-empty ``model_id``.
    """
    nothing_to_identify = self.model_id is None and self.provider_model_id is None
    if nothing_to_identify:
        raise ValueError("provider_model_id must be provided")

    # Supplying the same value twice is the legacy spelling: still accepted,
    # but the caller is told that `model_id` can be dropped.
    redundant_model_id = self.model_id == self.provider_model_id
    if redundant_model_id:
        logger.warning(
            f"`model_id` is now optional. The behavior of this field will change if model_id == provider_model_id. Please remove `model_id` and use `provider_model_id` instead.: {self.model_id}"
        )

    # An explicit alias and "use the provider id as the id" contradict each other.
    if self.model_id and self.use_provider_model_id_as_id:
        raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {self.model_id}")
class ListModelsResponse(BaseModel): class ListModelsResponse(BaseModel):

View file

@ -26,6 +26,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.log import LoggingConfig
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@ -185,14 +186,6 @@ class DistributionSpec(BaseModel):
) )
class LoggingConfig(BaseModel):
category_levels: dict[str, str] = Field(
default_factory=dict,
description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
)
class OAuth2JWKSConfig(BaseModel): class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys # The JWKS URI for collecting public keys
uri: str uri: str

View file

@ -69,11 +69,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def register_model( async def register_model(
self, self,
model_id: str, model_id: str | None = None,
provider_model_id: str | None = None, provider_model_id: str | None = None,
provider_id: str | None = None, provider_id: str | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None, model_type: ModelType | None = None,
use_provider_model_id_as_id: bool = False,
) -> Model: ) -> Model:
if provider_id is None: if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model # If provider_id not specified, use the only provider if it supports this model
@ -85,6 +86,17 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'." "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
) )
if model_id is None and provider_model_id is None:
raise ValueError("provider_model_id must be provided")
if model_id == provider_model_id:
logger.warning(
f"`model_id` is now optional. Please remove `{model_id=}` and use `{provider_model_id=}` instead."
)
if use_provider_model_id_as_id and model_id:
raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {model_id=}")
provider_model_id = provider_model_id or model_id provider_model_id = provider_model_id or model_id
metadata = metadata or {} metadata = metadata or {}
model_type = model_type or ModelType.llm model_type = model_type or ModelType.llm
@ -94,8 +106,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# an identifier different than provider_model_id implies it is an alias, so that # an identifier different than provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. otherwise provider_model_ids can conflict, # becomes the globally unique identifier. otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate. # so as a general rule we must use the provider_id to disambiguate.
if use_provider_model_id_as_id:
if model_id != provider_model_id: identifier = provider_model_id
elif model_id and model_id != provider_model_id:
identifier = model_id identifier = model_id
else: else:
identifier = f"{provider_id}/{provider_model_id}" identifier = f"{provider_id}/{provider_model_id}"

View file

@ -79,15 +79,15 @@ def get_distribution_template() -> DistributionTemplate:
) )
inference_model = ModelInput( inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}", provider_model_id="${env.INFERENCE_MODEL}",
provider_id="tgi0", provider_id="tgi0",
) )
safety_model = ModelInput( safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}", provider_model_id="${env.SAFETY_MODEL}",
provider_id="tgi1", provider_id="tgi1",
) )
embedding_model = ModelInput( embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2", provider_model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers", provider_id="sentence-transformers",
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={ metadata={

View file

@ -103,18 +103,21 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} provider_model_id: ${env.INFERENCE_MODEL}
provider_id: tgi0 provider_id: tgi0
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: {} - metadata: {}
model_id: ${env.SAFETY_MODEL} provider_model_id: ${env.SAFETY_MODEL}
provider_id: tgi1 provider_id: tgi1
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 provider_model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers provider_id: sentence-transformers
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
shields: shields:
- shield_id: ${env.SAFETY_MODEL} - shield_id: ${env.SAFETY_MODEL}
vector_dbs: [] vector_dbs: []

View file

@ -99,14 +99,16 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} provider_model_id: ${env.INFERENCE_MODEL}
provider_id: tgi0 provider_id: tgi0
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 provider_model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers provider_id: sentence-transformers
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
shields: [] shields: []
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []

View file

@ -73,11 +73,11 @@ def get_distribution_template() -> DistributionTemplate:
) )
inference_model = ModelInput( inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}", provider_model_id="${env.INFERENCE_MODEL}",
provider_id="meta-reference-inference", provider_id="meta-reference-inference",
) )
embedding_model = ModelInput( embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2", provider_model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers", provider_id="sentence-transformers",
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={ metadata={
@ -85,7 +85,7 @@ def get_distribution_template() -> DistributionTemplate:
}, },
) )
safety_model = ModelInput( safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}", provider_model_id="${env.SAFETY_MODEL}",
provider_id="meta-reference-safety", provider_id="meta-reference-safety",
) )
default_tool_groups = [ default_tool_groups = [

View file

@ -116,18 +116,21 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} provider_model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference provider_id: meta-reference-inference
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: {} - metadata: {}
model_id: ${env.SAFETY_MODEL} provider_model_id: ${env.SAFETY_MODEL}
provider_id: meta-reference-safety provider_id: meta-reference-safety
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 provider_model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers provider_id: sentence-transformers
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
shields: shields:
- shield_id: ${env.SAFETY_MODEL} - shield_id: ${env.SAFETY_MODEL}
vector_dbs: [] vector_dbs: []

View file

@ -106,14 +106,16 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} provider_model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference provider_id: meta-reference-inference
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 provider_model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers provider_id: sentence-transformers
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
shields: [] shields: []
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []

View file

@ -53,11 +53,11 @@ def get_distribution_template() -> DistributionTemplate:
config=NVIDIAEvalConfig.sample_run_config(), config=NVIDIAEvalConfig.sample_run_config(),
) )
inference_model = ModelInput( inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}", provider_model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia", provider_id="nvidia",
) )
safety_model = ModelInput( safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}", provider_model_id="${env.SAFETY_MODEL}",
provider_id="nvidia", provider_id="nvidia",
) )

View file

@ -96,13 +96,15 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} provider_model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia provider_id: nvidia
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: {} - metadata: {}
model_id: ${env.SAFETY_MODEL} provider_model_id: ${env.SAFETY_MODEL}
provider_id: nvidia provider_id: nvidia
model_type: llm model_type: llm
use_provider_model_id_as_id: false
shields: shields:
- shield_id: ${env.SAFETY_MODEL} - shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia provider_id: nvidia

View file

@ -85,88 +85,88 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: meta/llama3-8b-instruct
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama3-70b-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama3-70b-instruct provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-8b-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.1-8b-instruct provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-70b-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.1-70b-instruct provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-405b-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.1-405b-instruct provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-1b-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.2-1b-instruct provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-3b-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.2-3b-instruct provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-11b-vision-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.2-11b-vision-instruct provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-90b-vision-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.2-90b-vision-instruct provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.3-70b-instruct
provider_id: nvidia provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: meta/llama-3.3-70b-instruct provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: {}
model_id: nvidia/vila
provider_id: nvidia provider_id: nvidia
provider_model_id: nvidia/vila
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: {}
provider_model_id: nvidia/vila
provider_id: nvidia
model_type: llm
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 2048 embedding_dimension: 2048
context_length: 8192 context_length: 8192
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: nvidia
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2 provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: nvidia
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 1024 embedding_dimension: 1024
context_length: 512 context_length: 512
model_id: nvidia/nv-embedqa-e5-v5
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-e5-v5 provider_model_id: nvidia/nv-embedqa-e5-v5
provider_id: nvidia
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 4096 embedding_dimension: 4096
context_length: 512 context_length: 512
model_id: nvidia/nv-embedqa-mistral-7b-v2
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2 provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
provider_id: nvidia
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 1024 embedding_dimension: 1024
context_length: 512 context_length: 512
model_id: snowflake/arctic-embed-l
provider_id: nvidia
provider_model_id: snowflake/arctic-embed-l provider_model_id: snowflake/arctic-embed-l
provider_id: nvidia
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
shields: [] shields: []
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []

View file

@ -136,30 +136,32 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: gpt-4o
provider_id: openai
provider_model_id: gpt-4o provider_model_id: gpt-4o
provider_id: openai
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: {} - metadata: {}
model_id: claude-3-5-sonnet-latest
provider_id: anthropic
provider_model_id: claude-3-5-sonnet-latest provider_model_id: claude-3-5-sonnet-latest
provider_id: anthropic
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: {} - metadata: {}
model_id: gemini/gemini-1.5-flash
provider_id: gemini
provider_model_id: gemini/gemini-1.5-flash provider_model_id: gemini/gemini-1.5-flash
provider_id: gemini
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: {} - metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile provider_model_id: groq/llama-3.3-70b-versatile
provider_id: groq
model_type: llm model_type: llm
model_id: meta-llama/Llama-3.3-70B-Instruct
use_provider_model_id_as_id: false
- metadata: {} - metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
provider_id: together
model_type: llm model_type: llm
model_id: meta-llama/Llama-3.1-405B-Instruct
use_provider_model_id_as_id: false
shields: shields:
- shield_id: meta-llama/Llama-Guard-3-8B - shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: [] vector_dbs: []

View file

@ -75,7 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
default_models = [ default_models = [
ModelInput( ModelInput(
model_id="${env.INFERENCE_MODEL}", provider_model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference", provider_id="vllm-inference",
) )
] ]
@ -85,7 +85,7 @@ def get_distribution_template() -> DistributionTemplate:
config=SentenceTransformersInferenceConfig.sample_run_config(), config=SentenceTransformersInferenceConfig.sample_run_config(),
) )
embedding_model = ModelInput( embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2", provider_model_id="all-MiniLM-L6-v2",
provider_id=embedding_provider.provider_id, provider_id=embedding_provider.provider_id,
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={ metadata={

View file

@ -88,14 +88,16 @@ inference_store:
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
models: models:
- metadata: {} - metadata: {}
model_id: ${env.INFERENCE_MODEL} provider_model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference provider_id: vllm-inference
model_type: llm model_type: llm
use_provider_model_id_as_id: false
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 provider_model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers provider_id: sentence-transformers
model_type: embedding model_type: embedding
use_provider_model_id_as_id: false
shields: shields:
- shield_id: meta-llama/Llama-Guard-3-8B - shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: [] vector_dbs: []

View file

@ -114,7 +114,7 @@ def get_model_registry(
identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
models.append( models.append(
ModelInput( ModelInput(
model_id=identifier, model_id=identifier if identifier != entry.provider_model_id else None,
provider_model_id=entry.provider_model_id, provider_model_id=entry.provider_model_id,
provider_id=provider_id, provider_id=provider_id,
model_type=entry.model_type, model_type=entry.model_type,

View file

@ -73,7 +73,7 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
] ]
embedding_model = ModelInput( embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2", provider_model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers", provider_id="sentence-transformers",
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={ metadata={

View file

@ -9,11 +9,19 @@ import os
import re import re
from logging.config import dictConfig # allow-direct-logging from logging.config import dictConfig # allow-direct-logging
from pydantic import BaseModel, Field
from rich.console import Console from rich.console import Console
from rich.errors import MarkupError from rich.errors import MarkupError
from rich.logging import RichHandler from rich.logging import RichHandler
from llama_stack.core.datatypes import LoggingConfig
# Pydantic model holding logging configuration, keyed by category.
class LoggingConfig(BaseModel):
# Maps a category name (the description gives "core" and "server" as examples)
# to a string value, presumably a log-level name; defaults to an empty dict.
category_levels: dict[str, str] = Field(
default_factory=dict,
description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
)
# Default log level # Default log level
DEFAULT_LOG_LEVEL = logging.INFO DEFAULT_LOG_LEVEL = logging.INFO

View file

@ -645,3 +645,88 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
# Cleanup # Cleanup
await table.shutdown() await table.shutdown()
async def test_models_register_with_use_provider_model_id_as_id(cached_disk_dist_registry):
    """With ``use_provider_model_id_as_id=True`` the bare provider model id becomes the identifier."""
    routing_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # Register without a model_id, asking for the provider's id to be used verbatim.
    await routing_table.register_model(
        provider_model_id="actual-provider-model",
        provider_id="test_provider",
        use_provider_model_id_as_id=True,
    )

    # Exactly one model exists and it is keyed by the un-prefixed provider model id.
    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    registered = listing.data[0]
    assert registered.identifier == "actual-provider-model"
    assert registered.provider_resource_id == "actual-provider-model"
    assert registered.provider_id == "test_provider"

    # The same identifier resolves through get_model.
    fetched = await routing_table.get_model("actual-provider-model")
    assert fetched.identifier == "actual-provider-model"
    assert fetched.provider_resource_id == "actual-provider-model"

    # Cleanup
    await routing_table.shutdown()
async def test_models_register_provider_model_id_only(cached_disk_dist_registry):
    """Registering with only ``provider_model_id`` yields a ``provider_id/``-prefixed identifier."""
    routing_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # No model_id and no use_provider_model_id_as_id: the namespaced form is expected.
    await routing_table.register_model(
        provider_model_id="llama-3.1-8b",
        provider_id="test_provider",
        model_type=ModelType.llm,
    )

    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    registered = listing.data[0]
    assert registered.identifier == "test_provider/llama-3.1-8b"
    assert registered.provider_resource_id == "llama-3.1-8b"
    assert registered.provider_id == "test_provider"

    # Lookup by the namespaced identifier works.
    fetched = await routing_table.get_model("test_provider/llama-3.1-8b")
    assert fetched.identifier == "test_provider/llama-3.1-8b"

    # Cleanup
    await routing_table.shutdown()
async def test_models_register_validation_errors(cached_disk_dist_registry):
    """Test register_model validation errors.

    Covers both argument checks performed by ``register_model``:
    no identifier at all, and an explicit ``model_id`` combined with
    ``use_provider_model_id_as_id=True``.
    """
    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Test error when neither model_id nor provider_model_id is provided
    with pytest.raises(ValueError, match="provider_model_id must be provided"):
        await table.register_model(provider_id="test_provider")

    # Test error when an explicit alias is combined with use_provider_model_id_as_id;
    # the two ids differ so only the conflict check can be the one that fires.
    with pytest.raises(ValueError, match="cannot be provided together"):
        await table.register_model(
            model_id="my-alias",
            provider_model_id="actual-provider-model",
            provider_id="test_provider",
            use_provider_model_id_as_id=True,
        )

    # Cleanup
    await table.shutdown()
async def test_models_register_backward_compatibility_warning(cached_disk_dist_registry):
    """Passing identical ``model_id`` and ``provider_model_id`` must log a deprecation-style warning."""
    from unittest.mock import patch

    routing_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # Swap the routing table module's logger for a mock so the warning can be inspected.
    with patch("llama_stack.core.routing_tables.models.logger") as fake_logger:
        await routing_table.register_model(
            model_id="same-model",
            provider_model_id="same-model",
            provider_id="test_provider",
        )

        # One warning, carrying both the hint and the offending value.
        fake_logger.warning.assert_called_once()
        logged_text = fake_logger.warning.call_args.args[0]
        assert "model_id` is now optional" in logged_text
        assert "provider_model_id='same-model'" in logged_text

    # Cleanup
    await routing_table.shutdown()