From 83a229554b0bbc2e5ffe6fe8f467441895becc79 Mon Sep 17 00:00:00 2001
From: Eric Huang
Date: Thu, 18 Sep 2025 15:47:20 -0700
Subject: [PATCH] feat: allow users to register a model alias explicitly, with tests

# What does this PR do?
Context: https://github.com/llamastack/llama-stack/discussions/3483

This PR enables registering `provider_model_id` as the model identifier
without breaking backward compatibility.

## Test Plan
Unit tests in tests/unit/distribution/routers/test_routing_tables.py cover the
new `use_provider_model_id_as_id` flag, registration with only
`provider_model_id`, validation errors, and the backward-compatibility warning
when model_id == provider_model_id.
---
 .../k8s-benchmark/stack_run_config.yaml       |  3 +-
 llama_stack/apis/models/models.py             | 32 ++++++-
 llama_stack/core/datatypes.py                 |  9 +-
 llama_stack/core/routing_tables/models.py     | 19 ++++-
 llama_stack/distributions/dell/dell.py        |  6 +-
 .../distributions/dell/run-with-safety.yaml   |  9 +-
 llama_stack/distributions/dell/run.yaml       |  6 +-
 .../meta-reference-gpu/meta_reference.py      |  6 +-
 .../meta-reference-gpu/run-with-safety.yaml   |  9 +-
 .../distributions/meta-reference-gpu/run.yaml |  6 +-
 llama_stack/distributions/nvidia/nvidia.py    |  4 +-
 .../distributions/nvidia/run-with-safety.yaml |  6 +-
 llama_stack/distributions/nvidia/run.yaml     | 82 +++++++++---------
 .../distributions/open-benchmark/run.yaml     | 22 ++---
 .../postgres-demo/postgres_demo.py            |  4 +-
 .../distributions/postgres-demo/run.yaml      |  6 +-
 llama_stack/distributions/template.py         |  2 +-
 llama_stack/distributions/watsonx/watsonx.py  |  2 +-
 llama_stack/log.py                            | 10 ++-
 .../routers/test_routing_tables.py            | 85 +++++++++++++++++++
 20 files changed, 236 insertions(+), 92 deletions(-)

diff --git a/benchmarking/k8s-benchmark/stack_run_config.yaml b/benchmarking/k8s-benchmark/stack_run_config.yaml
index 5a9e2ae4f..f47c3211b 100644
--- a/benchmarking/k8s-benchmark/stack_run_config.yaml
+++ b/benchmarking/k8s-benchmark/stack_run_config.yaml
@@ -116,7 +116,8 @@ models:
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
   model_type: embedding
-- model_id: ${env.INFERENCE_MODEL}
+- use_provider_model_id_as_id: true
+  provider_model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   model_type: llm
 shields:
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 1af6fc9df..384d5799a 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -10,9 +10,12 @@ from typing import Any, Literal, Protocol, runtime_checkable
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
 
+logger = get_logger(name=__name__, category="core")
+
 
 class CommonModelFields(BaseModel):
     metadata: dict[str, Any] = Field(
@@ -68,11 +71,36 @@ class Model(CommonModelFields, Resource):
 
 
 class ModelInput(CommonModelFields):
-    model_id: str
-    provider_id: str | None = None
+    """A model input for registering a model.
+
+    :param provider_model_id: The identifier of the model in the provider.
+    :param provider_id: The identifier of the provider.
+    :param model_type: The type of model to register.
+    :param model_id: The identifier of the model to register. If model_id == provider_model_id,
+        provider_id/provider_model_id will be used as the identifier; otherwise model_id will be used as the identifier.
+        The behavior of this field will soon change to "always use model_id as the identifier".
+    :param use_provider_model_id_as_id: Set to true to use provider_model_id as the identifier.
+        Use model_id if you want to use a different identifier.
+    """
+    provider_model_id: str | None = None
+    provider_id: str | None = None
     model_type: ModelType | None = ModelType.llm
 
     model_config = ConfigDict(protected_namespaces=())
+    # TODO: update behavior of this field to always be the identifier
+    model_id: str | None = None
+    use_provider_model_id_as_id: bool = False
+
+    def model_post_init(self, __context: Any) -> None:
+        if self.model_id is None and self.provider_model_id is None:
+            raise ValueError("either model_id or provider_model_id must be provided")
+
+        if self.model_id == self.provider_model_id:
+            logger.warning(
+                f"`model_id` is now optional and its behavior will change soon. Since model_id == provider_model_id, please remove `model_id` and use `provider_model_id` instead: {self.model_id}"
+            )
+
+        if self.use_provider_model_id_as_id and self.model_id:
+            raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {self.model_id}")
 
 
 class ListModelsResponse(BaseModel):
diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py
index faaeefd01..b805160dd 100644
--- a/llama_stack/core/datatypes.py
+++ b/llama_stack/core/datatypes.py
@@ -26,6 +26,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.access_control.datatypes import AccessRule
+from llama_stack.log import LoggingConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@@ -185,14 +186,6 @@ class DistributionSpec(BaseModel):
     )
 
 
-class LoggingConfig(BaseModel):
-    category_levels: dict[str, str] = Field(
-        default_factory=dict,
-        description="""
-    Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
-    )
-
-
 class OAuth2JWKSConfig(BaseModel):
     # The JWKS URI for collecting public keys
     uri: str
diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py
index b6141efa9..3c230aded 100644
--- a/llama_stack/core/routing_tables/models.py
+++ b/llama_stack/core/routing_tables/models.py
@@ -69,11 +69,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
     async def register_model(
         self,
-        model_id: str,
+        model_id: str | None = None,
         provider_model_id: str | None = None,
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
+        use_provider_model_id_as_id: bool = False,
     ) -> Model:
         if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this model
@@ -85,6 +86,17 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
             )
 
+        if model_id is None and provider_model_id is None:
+            raise ValueError("either model_id or provider_model_id must be provided")
+
+        if model_id == provider_model_id:
+            logger.warning(
+                f"`model_id` is now optional. Please remove `{model_id=}` and use `{provider_model_id=}` instead."
+ ) + + if use_provider_model_id_as_id and model_id: + raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {model_id=}") + provider_model_id = provider_model_id or model_id metadata = metadata or {} model_type = model_type or ModelType.llm @@ -94,8 +106,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): # an identifier different than provider_model_id implies it is an alias, so that # becomes the globally unique identifier. otherwise provider_model_ids can conflict, # so as a general rule we must use the provider_id to disambiguate. - - if model_id != provider_model_id: + if use_provider_model_id_as_id: + identifier = provider_model_id + elif model_id and model_id != provider_model_id: identifier = model_id else: identifier = f"{provider_id}/{provider_model_id}" diff --git a/llama_stack/distributions/dell/dell.py b/llama_stack/distributions/dell/dell.py index e3bf0ee03..5097070e4 100644 --- a/llama_stack/distributions/dell/dell.py +++ b/llama_stack/distributions/dell/dell.py @@ -79,15 +79,15 @@ def get_distribution_template() -> DistributionTemplate: ) inference_model = ModelInput( - model_id="${env.INFERENCE_MODEL}", + provider_model_id="${env.INFERENCE_MODEL}", provider_id="tgi0", ) safety_model = ModelInput( - model_id="${env.SAFETY_MODEL}", + provider_model_id="${env.SAFETY_MODEL}", provider_id="tgi1", ) embedding_model = ModelInput( - model_id="all-MiniLM-L6-v2", + provider_model_id="all-MiniLM-L6-v2", provider_id="sentence-transformers", model_type=ModelType.embedding, metadata={ diff --git a/llama_stack/distributions/dell/run-with-safety.yaml b/llama_stack/distributions/dell/run-with-safety.yaml index d89c92aa1..d00b88656 100644 --- a/llama_stack/distributions/dell/run-with-safety.yaml +++ b/llama_stack/distributions/dell/run-with-safety.yaml @@ -103,18 +103,21 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db models: - metadata: {} - model_id: ${env.INFERENCE_MODEL} + provider_model_id: ${env.INFERENCE_MODEL} provider_id: tgi0 model_type: llm + use_provider_model_id_as_id: false - metadata: {} - model_id: ${env.SAFETY_MODEL} + provider_model_id: ${env.SAFETY_MODEL} provider_id: tgi1 model_type: llm + use_provider_model_id_as_id: false - metadata: embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 + provider_model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers model_type: embedding + use_provider_model_id_as_id: false shields: - shield_id: ${env.SAFETY_MODEL} vector_dbs: [] diff --git a/llama_stack/distributions/dell/run.yaml b/llama_stack/distributions/dell/run.yaml index 7397410ba..7936e2da8 100644 --- a/llama_stack/distributions/dell/run.yaml +++ b/llama_stack/distributions/dell/run.yaml @@ -99,14 +99,16 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db models: - metadata: {} - model_id: ${env.INFERENCE_MODEL} + provider_model_id: ${env.INFERENCE_MODEL} provider_id: tgi0 model_type: llm + use_provider_model_id_as_id: false - metadata: embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 + provider_model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers model_type: embedding + use_provider_model_id_as_id: false shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/distributions/meta-reference-gpu/meta_reference.py b/llama_stack/distributions/meta-reference-gpu/meta_reference.py index 78bebb24c..b48d0aea5 100644 --- a/llama_stack/distributions/meta-reference-gpu/meta_reference.py +++ 
b/llama_stack/distributions/meta-reference-gpu/meta_reference.py @@ -73,11 +73,11 @@ def get_distribution_template() -> DistributionTemplate: ) inference_model = ModelInput( - model_id="${env.INFERENCE_MODEL}", + provider_model_id="${env.INFERENCE_MODEL}", provider_id="meta-reference-inference", ) embedding_model = ModelInput( - model_id="all-MiniLM-L6-v2", + provider_model_id="all-MiniLM-L6-v2", provider_id="sentence-transformers", model_type=ModelType.embedding, metadata={ @@ -85,7 +85,7 @@ def get_distribution_template() -> DistributionTemplate: }, ) safety_model = ModelInput( - model_id="${env.SAFETY_MODEL}", + provider_model_id="${env.SAFETY_MODEL}", provider_id="meta-reference-safety", ) default_tool_groups = [ diff --git a/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml b/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml index 910f9ec46..2028f4d73 100644 --- a/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml @@ -116,18 +116,21 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} - model_id: ${env.INFERENCE_MODEL} + provider_model_id: ${env.INFERENCE_MODEL} provider_id: meta-reference-inference model_type: llm + use_provider_model_id_as_id: false - metadata: {} - model_id: ${env.SAFETY_MODEL} + provider_model_id: ${env.SAFETY_MODEL} provider_id: meta-reference-safety model_type: llm + use_provider_model_id_as_id: false - metadata: embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 + provider_model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers model_type: embedding + use_provider_model_id_as_id: false shields: - shield_id: ${env.SAFETY_MODEL} vector_dbs: [] diff --git a/llama_stack/distributions/meta-reference-gpu/run.yaml b/llama_stack/distributions/meta-reference-gpu/run.yaml index 5266f3c84..3dc898388 100644 --- a/llama_stack/distributions/meta-reference-gpu/run.yaml +++ b/llama_stack/distributions/meta-reference-gpu/run.yaml @@ -106,14 +106,16 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} - model_id: ${env.INFERENCE_MODEL} + provider_model_id: ${env.INFERENCE_MODEL} provider_id: meta-reference-inference model_type: llm + use_provider_model_id_as_id: false - metadata: embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 + provider_model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers model_type: embedding + use_provider_model_id_as_id: false shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/distributions/nvidia/nvidia.py b/llama_stack/distributions/nvidia/nvidia.py index aedda0ae9..1ad960536 100644 --- a/llama_stack/distributions/nvidia/nvidia.py +++ b/llama_stack/distributions/nvidia/nvidia.py @@ -53,11 +53,11 @@ def get_distribution_template() -> DistributionTemplate: config=NVIDIAEvalConfig.sample_run_config(), ) inference_model = ModelInput( - model_id="${env.INFERENCE_MODEL}", + provider_model_id="${env.INFERENCE_MODEL}", provider_id="nvidia", ) safety_model = ModelInput( - model_id="${env.SAFETY_MODEL}", + provider_model_id="${env.SAFETY_MODEL}", provider_id="nvidia", ) diff --git a/llama_stack/distributions/nvidia/run-with-safety.yaml b/llama_stack/distributions/nvidia/run-with-safety.yaml index 015724050..9c351bb74 100644 --- a/llama_stack/distributions/nvidia/run-with-safety.yaml +++ b/llama_stack/distributions/nvidia/run-with-safety.yaml @@ 
-96,13 +96,15 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} - model_id: ${env.INFERENCE_MODEL} + provider_model_id: ${env.INFERENCE_MODEL} provider_id: nvidia model_type: llm + use_provider_model_id_as_id: false - metadata: {} - model_id: ${env.SAFETY_MODEL} + provider_model_id: ${env.SAFETY_MODEL} provider_id: nvidia model_type: llm + use_provider_model_id_as_id: false shields: - shield_id: ${env.SAFETY_MODEL} provider_id: nvidia diff --git a/llama_stack/distributions/nvidia/run.yaml b/llama_stack/distributions/nvidia/run.yaml index 9fd6b0404..fe8a2d4f3 100644 --- a/llama_stack/distributions/nvidia/run.yaml +++ b/llama_stack/distributions/nvidia/run.yaml @@ -85,88 +85,88 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} - model_id: meta/llama3-8b-instruct - provider_id: nvidia provider_model_id: meta/llama3-8b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama3-70b-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.1-8b-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.1-70b-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.1-405b-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.1-405b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-1b-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.2-1b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-3b-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-11b-vision-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-90b-vision-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.3-70b-instruct provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false +- metadata: {} provider_model_id: meta/llama-3.3-70b-instruct - model_type: llm -- metadata: {} - model_id: nvidia/vila provider_id: nvidia - provider_model_id: nvidia/vila model_type: llm + use_provider_model_id_as_id: false +- metadata: {} + provider_model_id: nvidia/vila + provider_id: nvidia + model_type: llm + use_provider_model_id_as_id: false - metadata: embedding_dimension: 2048 context_length: 8192 - model_id: nvidia/llama-3.2-nv-embedqa-1b-v2 - provider_id: nvidia provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2 + provider_id: nvidia model_type: embedding + use_provider_model_id_as_id: false - metadata: embedding_dimension: 1024 context_length: 512 - model_id: 
nvidia/nv-embedqa-e5-v5 - provider_id: nvidia provider_model_id: nvidia/nv-embedqa-e5-v5 + provider_id: nvidia model_type: embedding + use_provider_model_id_as_id: false - metadata: embedding_dimension: 4096 context_length: 512 - model_id: nvidia/nv-embedqa-mistral-7b-v2 - provider_id: nvidia provider_model_id: nvidia/nv-embedqa-mistral-7b-v2 + provider_id: nvidia model_type: embedding + use_provider_model_id_as_id: false - metadata: embedding_dimension: 1024 context_length: 512 - model_id: snowflake/arctic-embed-l - provider_id: nvidia provider_model_id: snowflake/arctic-embed-l + provider_id: nvidia model_type: embedding + use_provider_model_id_as_id: false shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/distributions/open-benchmark/run.yaml b/llama_stack/distributions/open-benchmark/run.yaml index d068a0b5a..ea5c1f650 100644 --- a/llama_stack/distributions/open-benchmark/run.yaml +++ b/llama_stack/distributions/open-benchmark/run.yaml @@ -136,30 +136,32 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db models: - metadata: {} - model_id: gpt-4o - provider_id: openai provider_model_id: gpt-4o + provider_id: openai model_type: llm + use_provider_model_id_as_id: false - metadata: {} - model_id: claude-3-5-sonnet-latest - provider_id: anthropic provider_model_id: claude-3-5-sonnet-latest + provider_id: anthropic model_type: llm + use_provider_model_id_as_id: false - metadata: {} - model_id: gemini/gemini-1.5-flash - provider_id: gemini provider_model_id: gemini/gemini-1.5-flash + provider_id: gemini model_type: llm + use_provider_model_id_as_id: false - metadata: {} - model_id: meta-llama/Llama-3.3-70B-Instruct - provider_id: groq provider_model_id: groq/llama-3.3-70b-versatile + provider_id: groq model_type: llm + model_id: meta-llama/Llama-3.3-70B-Instruct + use_provider_model_id_as_id: false - metadata: {} - model_id: meta-llama/Llama-3.1-405B-Instruct - provider_id: together provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo + provider_id: together model_type: llm + model_id: meta-llama/Llama-3.1-405B-Instruct + use_provider_model_id_as_id: false shields: - shield_id: meta-llama/Llama-Guard-3-8B vector_dbs: [] diff --git a/llama_stack/distributions/postgres-demo/postgres_demo.py b/llama_stack/distributions/postgres-demo/postgres_demo.py index c04cfedfa..f322379a0 100644 --- a/llama_stack/distributions/postgres-demo/postgres_demo.py +++ b/llama_stack/distributions/postgres-demo/postgres_demo.py @@ -75,7 +75,7 @@ def get_distribution_template() -> DistributionTemplate: default_models = [ ModelInput( - model_id="${env.INFERENCE_MODEL}", + provider_model_id="${env.INFERENCE_MODEL}", provider_id="vllm-inference", ) ] @@ -85,7 +85,7 @@ def get_distribution_template() -> DistributionTemplate: config=SentenceTransformersInferenceConfig.sample_run_config(), ) embedding_model = ModelInput( - model_id="all-MiniLM-L6-v2", + provider_model_id="all-MiniLM-L6-v2", provider_id=embedding_provider.provider_id, model_type=ModelType.embedding, metadata={ diff --git a/llama_stack/distributions/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml index 0cf0e82e6..28b2e50a2 100644 --- a/llama_stack/distributions/postgres-demo/run.yaml +++ b/llama_stack/distributions/postgres-demo/run.yaml @@ -88,14 +88,16 @@ inference_store: password: ${env.POSTGRES_PASSWORD:=llamastack} models: - metadata: {} - model_id: ${env.INFERENCE_MODEL} + provider_model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference 
model_type: llm + use_provider_model_id_as_id: false - metadata: embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 + provider_model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers model_type: embedding + use_provider_model_id_as_id: false shields: - shield_id: meta-llama/Llama-Guard-3-8B vector_dbs: [] diff --git a/llama_stack/distributions/template.py b/llama_stack/distributions/template.py index d564312dc..03976d74c 100644 --- a/llama_stack/distributions/template.py +++ b/llama_stack/distributions/template.py @@ -114,7 +114,7 @@ def get_model_registry( identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id models.append( ModelInput( - model_id=identifier, + model_id=identifier if identifier != entry.provider_model_id else None, provider_model_id=entry.provider_model_id, provider_id=provider_id, model_type=entry.model_type, diff --git a/llama_stack/distributions/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py index c3cab5d1b..fa6c1b4b5 100644 --- a/llama_stack/distributions/watsonx/watsonx.py +++ b/llama_stack/distributions/watsonx/watsonx.py @@ -73,7 +73,7 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: ] embedding_model = ModelInput( - model_id="all-MiniLM-L6-v2", + provider_model_id="all-MiniLM-L6-v2", provider_id="sentence-transformers", model_type=ModelType.embedding, metadata={ diff --git a/llama_stack/log.py b/llama_stack/log.py index cc4c9d4cf..ccff34758 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -9,11 +9,19 @@ import os import re from logging.config import dictConfig # allow-direct-logging +from pydantic import BaseModel, Field from rich.console import Console from rich.errors import MarkupError from rich.logging import RichHandler -from llama_stack.core.datatypes import LoggingConfig + +class LoggingConfig(BaseModel): + category_levels: dict[str, str] = Field( + default_factory=dict, + description=""" + Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""", + ) + # Default log level DEFAULT_LOG_LEVEL = logging.INFO diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index bbfea3f46..6fc08e5bb 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -645,3 +645,88 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis # Cleanup await table.shutdown() + + +async def test_models_register_with_use_provider_model_id_as_id(cached_disk_dist_registry): + """Test register_model with the new use_provider_model_id_as_id parameter.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model using use_provider_model_id_as_id parameter + await table.register_model( + provider_model_id="actual-provider-model", provider_id="test_provider", use_provider_model_id_as_id=True + ) + + # Verify the model was registered with provider_model_id as identifier + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.identifier == "actual-provider-model" + assert model.provider_resource_id == "actual-provider-model" + assert model.provider_id == "test_provider" + + # Test lookup by provider_model_id works + retrieved_model = await table.get_model("actual-provider-model") + assert retrieved_model.identifier == 
"actual-provider-model" + assert retrieved_model.provider_resource_id == "actual-provider-model" + + # Cleanup + await table.shutdown() + + +async def test_models_register_provider_model_id_only(cached_disk_dist_registry): + """Test register_model with only provider_model_id (new recommended usage).""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model using only provider_model_id + await table.register_model(provider_model_id="llama-3.1-8b", provider_id="test_provider", model_type=ModelType.llm) + + # Verify the model was registered with namespaced identifier + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.identifier == "test_provider/llama-3.1-8b" + assert model.provider_resource_id == "llama-3.1-8b" + assert model.provider_id == "test_provider" + + # Test lookup works + retrieved_model = await table.get_model("test_provider/llama-3.1-8b") + assert retrieved_model.identifier == "test_provider/llama-3.1-8b" + + # Cleanup + await table.shutdown() + + +async def test_models_register_validation_errors(cached_disk_dist_registry): + """Test register_model validation errors.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Test error when neither model_id nor provider_model_id is provided + with pytest.raises(ValueError, match="provider_model_id must be provided"): + await table.register_model(provider_id="test_provider") + + # Cleanup + await table.shutdown() + + +async def test_models_register_backward_compatibility_warning(cached_disk_dist_registry): + """Test that register_model warns when model_id equals provider_model_id.""" + from unittest.mock import patch + + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Test warning is logged when model_id == provider_model_id + with patch("llama_stack.core.routing_tables.models.logger") as mock_logger: + await table.register_model(model_id="same-model", provider_model_id="same-model", provider_id="test_provider") + + # Verify warning was called + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "model_id` is now optional" in warning_msg + assert "provider_model_id='same-model'" in warning_msg + + # Cleanup + await table.shutdown()