mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat: allow user to register model alias explicitly, tests
# What does this PR do? Context: https://github.com/llamastack/llama-stack/discussions/3483 This PR enables the registering `provider_model_id` as the model identifier without breaking backward compatibility. ## Test Plan todo # What does this PR do? ## Test Plan
This commit is contained in:
parent
ac1414b571
commit
83a229554b
20 changed files with 236 additions and 92 deletions
|
@ -116,7 +116,8 @@ models:
|
||||||
model_id: all-MiniLM-L6-v2
|
model_id: all-MiniLM-L6-v2
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
- model_id: ${env.INFERENCE_MODEL}
|
- use_provider_model_id_as_id: true
|
||||||
|
provider_model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: vllm-inference
|
provider_id: vllm-inference
|
||||||
model_type: llm
|
model_type: llm
|
||||||
shields:
|
shields:
|
||||||
|
|
|
@ -10,9 +10,12 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
class CommonModelFields(BaseModel):
|
class CommonModelFields(BaseModel):
|
||||||
metadata: dict[str, Any] = Field(
|
metadata: dict[str, Any] = Field(
|
||||||
|
@ -68,11 +71,36 @@ class Model(CommonModelFields, Resource):
|
||||||
|
|
||||||
|
|
||||||
class ModelInput(CommonModelFields):
|
class ModelInput(CommonModelFields):
|
||||||
model_id: str
|
"""A model input for registering a model.
|
||||||
provider_id: str | None = None
|
|
||||||
|
:param provider_model_id: The identifier of the model in the provider.
|
||||||
|
:param provider_id: The identifier of the provider.
|
||||||
|
:param model_type: The type of model to register.
|
||||||
|
:param model_id: The identifier of the model to register. If model_id == provider_model_id, provider_id/provider_model_id will be used as the identifier. Otherwise,
|
||||||
|
model_id will be used as the identifier.
|
||||||
|
The behavior of this field will soon change to "always use model_id as the identifier".
|
||||||
|
:param use_provider_model_id_as_id: Set to true to use provider_model_id as the identifier. Use model_id if you want to use a different identifier.
|
||||||
|
"""
|
||||||
|
|
||||||
provider_model_id: str | None = None
|
provider_model_id: str | None = None
|
||||||
|
provider_id: str | None = None
|
||||||
model_type: ModelType | None = ModelType.llm
|
model_type: ModelType | None = ModelType.llm
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
# TODO: update behavior of this field to always be the identifier
|
||||||
|
model_id: str | None = None
|
||||||
|
use_provider_model_id_as_id: bool = False
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
if self.model_id is None and self.provider_model_id is None:
|
||||||
|
raise ValueError("provider_model_id must be provided")
|
||||||
|
|
||||||
|
if self.model_id == self.provider_model_id:
|
||||||
|
logger.warning(
|
||||||
|
f"`model_id` is now optional. The behavior of this field will change if model_id == provider_model_id. Please remove `model_id` and use `provider_model_id` instead.: {self.model_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_provider_model_id_as_id and self.model_id:
|
||||||
|
raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {self.model_id}")
|
||||||
|
|
||||||
|
|
||||||
class ListModelsResponse(BaseModel):
|
class ListModelsResponse(BaseModel):
|
||||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.core.access_control.datatypes import AccessRule
|
from llama_stack.core.access_control.datatypes import AccessRule
|
||||||
|
from llama_stack.log import LoggingConfig
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||||
|
@ -185,14 +186,6 @@ class DistributionSpec(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoggingConfig(BaseModel):
|
|
||||||
category_levels: dict[str, str] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="""
|
|
||||||
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OAuth2JWKSConfig(BaseModel):
|
class OAuth2JWKSConfig(BaseModel):
|
||||||
# The JWKS URI for collecting public keys
|
# The JWKS URI for collecting public keys
|
||||||
uri: str
|
uri: str
|
||||||
|
|
|
@ -69,11 +69,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
||||||
async def register_model(
|
async def register_model(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str | None = None,
|
||||||
provider_model_id: str | None = None,
|
provider_model_id: str | None = None,
|
||||||
provider_id: str | None = None,
|
provider_id: str | None = None,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
model_type: ModelType | None = None,
|
model_type: ModelType | None = None,
|
||||||
|
use_provider_model_id_as_id: bool = False,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
# If provider_id not specified, use the only provider if it supports this model
|
# If provider_id not specified, use the only provider if it supports this model
|
||||||
|
@ -85,6 +86,17 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
|
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_id is None and provider_model_id is None:
|
||||||
|
raise ValueError("provider_model_id must be provided")
|
||||||
|
|
||||||
|
if model_id == provider_model_id:
|
||||||
|
logger.warning(
|
||||||
|
f"`model_id` is now optional. Please remove `{model_id=}` and use `{provider_model_id=}` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_provider_model_id_as_id and model_id:
|
||||||
|
raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {model_id=}")
|
||||||
|
|
||||||
provider_model_id = provider_model_id or model_id
|
provider_model_id = provider_model_id or model_id
|
||||||
metadata = metadata or {}
|
metadata = metadata or {}
|
||||||
model_type = model_type or ModelType.llm
|
model_type = model_type or ModelType.llm
|
||||||
|
@ -94,8 +106,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
# an identifier different than provider_model_id implies it is an alias, so that
|
# an identifier different than provider_model_id implies it is an alias, so that
|
||||||
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
||||||
# so as a general rule we must use the provider_id to disambiguate.
|
# so as a general rule we must use the provider_id to disambiguate.
|
||||||
|
if use_provider_model_id_as_id:
|
||||||
if model_id != provider_model_id:
|
identifier = provider_model_id
|
||||||
|
elif model_id and model_id != provider_model_id:
|
||||||
identifier = model_id
|
identifier = model_id
|
||||||
else:
|
else:
|
||||||
identifier = f"{provider_id}/{provider_model_id}"
|
identifier = f"{provider_id}/{provider_model_id}"
|
||||||
|
|
|
@ -79,15 +79,15 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
)
|
)
|
||||||
|
|
||||||
inference_model = ModelInput(
|
inference_model = ModelInput(
|
||||||
model_id="${env.INFERENCE_MODEL}",
|
provider_model_id="${env.INFERENCE_MODEL}",
|
||||||
provider_id="tgi0",
|
provider_id="tgi0",
|
||||||
)
|
)
|
||||||
safety_model = ModelInput(
|
safety_model = ModelInput(
|
||||||
model_id="${env.SAFETY_MODEL}",
|
provider_model_id="${env.SAFETY_MODEL}",
|
||||||
provider_id="tgi1",
|
provider_id="tgi1",
|
||||||
)
|
)
|
||||||
embedding_model = ModelInput(
|
embedding_model = ModelInput(
|
||||||
model_id="all-MiniLM-L6-v2",
|
provider_model_id="all-MiniLM-L6-v2",
|
||||||
provider_id="sentence-transformers",
|
provider_id="sentence-transformers",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={
|
metadata={
|
||||||
|
|
|
@ -103,18 +103,21 @@ inference_store:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.INFERENCE_MODEL}
|
provider_model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: tgi0
|
provider_id: tgi0
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.SAFETY_MODEL}
|
provider_model_id: ${env.SAFETY_MODEL}
|
||||||
provider_id: tgi1
|
provider_id: tgi1
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 384
|
embedding_dimension: 384
|
||||||
model_id: all-MiniLM-L6-v2
|
provider_model_id: all-MiniLM-L6-v2
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields:
|
shields:
|
||||||
- shield_id: ${env.SAFETY_MODEL}
|
- shield_id: ${env.SAFETY_MODEL}
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
|
|
|
@ -99,14 +99,16 @@ inference_store:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.INFERENCE_MODEL}
|
provider_model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: tgi0
|
provider_id: tgi0
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 384
|
embedding_dimension: 384
|
||||||
model_id: all-MiniLM-L6-v2
|
provider_model_id: all-MiniLM-L6-v2
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields: []
|
shields: []
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
|
|
|
@ -73,11 +73,11 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
)
|
)
|
||||||
|
|
||||||
inference_model = ModelInput(
|
inference_model = ModelInput(
|
||||||
model_id="${env.INFERENCE_MODEL}",
|
provider_model_id="${env.INFERENCE_MODEL}",
|
||||||
provider_id="meta-reference-inference",
|
provider_id="meta-reference-inference",
|
||||||
)
|
)
|
||||||
embedding_model = ModelInput(
|
embedding_model = ModelInput(
|
||||||
model_id="all-MiniLM-L6-v2",
|
provider_model_id="all-MiniLM-L6-v2",
|
||||||
provider_id="sentence-transformers",
|
provider_id="sentence-transformers",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={
|
metadata={
|
||||||
|
@ -85,7 +85,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
safety_model = ModelInput(
|
safety_model = ModelInput(
|
||||||
model_id="${env.SAFETY_MODEL}",
|
provider_model_id="${env.SAFETY_MODEL}",
|
||||||
provider_id="meta-reference-safety",
|
provider_id="meta-reference-safety",
|
||||||
)
|
)
|
||||||
default_tool_groups = [
|
default_tool_groups = [
|
||||||
|
|
|
@ -116,18 +116,21 @@ inference_store:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.INFERENCE_MODEL}
|
provider_model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: meta-reference-inference
|
provider_id: meta-reference-inference
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.SAFETY_MODEL}
|
provider_model_id: ${env.SAFETY_MODEL}
|
||||||
provider_id: meta-reference-safety
|
provider_id: meta-reference-safety
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 384
|
embedding_dimension: 384
|
||||||
model_id: all-MiniLM-L6-v2
|
provider_model_id: all-MiniLM-L6-v2
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields:
|
shields:
|
||||||
- shield_id: ${env.SAFETY_MODEL}
|
- shield_id: ${env.SAFETY_MODEL}
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
|
|
|
@ -106,14 +106,16 @@ inference_store:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.INFERENCE_MODEL}
|
provider_model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: meta-reference-inference
|
provider_id: meta-reference-inference
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 384
|
embedding_dimension: 384
|
||||||
model_id: all-MiniLM-L6-v2
|
provider_model_id: all-MiniLM-L6-v2
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields: []
|
shields: []
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
|
|
|
@ -53,11 +53,11 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
config=NVIDIAEvalConfig.sample_run_config(),
|
config=NVIDIAEvalConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
inference_model = ModelInput(
|
inference_model = ModelInput(
|
||||||
model_id="${env.INFERENCE_MODEL}",
|
provider_model_id="${env.INFERENCE_MODEL}",
|
||||||
provider_id="nvidia",
|
provider_id="nvidia",
|
||||||
)
|
)
|
||||||
safety_model = ModelInput(
|
safety_model = ModelInput(
|
||||||
model_id="${env.SAFETY_MODEL}",
|
provider_model_id="${env.SAFETY_MODEL}",
|
||||||
provider_id="nvidia",
|
provider_id="nvidia",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -96,13 +96,15 @@ inference_store:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.INFERENCE_MODEL}
|
provider_model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.SAFETY_MODEL}
|
provider_model_id: ${env.SAFETY_MODEL}
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields:
|
shields:
|
||||||
- shield_id: ${env.SAFETY_MODEL}
|
- shield_id: ${env.SAFETY_MODEL}
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
|
|
@ -85,88 +85,88 @@ inference_store:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama3-8b-instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama3-8b-instruct
|
provider_model_id: meta/llama3-8b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama3-70b-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama3-70b-instruct
|
provider_model_id: meta/llama3-70b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.1-8b-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.1-8b-instruct
|
provider_model_id: meta/llama-3.1-8b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.1-70b-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.1-70b-instruct
|
provider_model_id: meta/llama-3.1-70b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.1-405b-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.1-405b-instruct
|
provider_model_id: meta/llama-3.1-405b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.2-1b-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.2-1b-instruct
|
provider_model_id: meta/llama-3.2-1b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.2-3b-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.2-3b-instruct
|
provider_model_id: meta/llama-3.2-3b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.2-11b-vision-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.2-90b-vision-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta/llama-3.3-70b-instruct
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
provider_model_id: meta/llama-3.3-70b-instruct
|
provider_model_id: meta/llama-3.3-70b-instruct
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: nvidia/vila
|
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: nvidia/vila
|
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
|
- metadata: {}
|
||||||
|
provider_model_id: nvidia/vila
|
||||||
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 2048
|
embedding_dimension: 2048
|
||||||
context_length: 8192
|
context_length: 8192
|
||||||
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||||
|
provider_id: nvidia
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 1024
|
embedding_dimension: 1024
|
||||||
context_length: 512
|
context_length: 512
|
||||||
model_id: nvidia/nv-embedqa-e5-v5
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: nvidia/nv-embedqa-e5-v5
|
provider_model_id: nvidia/nv-embedqa-e5-v5
|
||||||
|
provider_id: nvidia
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 4096
|
embedding_dimension: 4096
|
||||||
context_length: 512
|
context_length: 512
|
||||||
model_id: nvidia/nv-embedqa-mistral-7b-v2
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
|
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||||
|
provider_id: nvidia
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 1024
|
embedding_dimension: 1024
|
||||||
context_length: 512
|
context_length: 512
|
||||||
model_id: snowflake/arctic-embed-l
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: snowflake/arctic-embed-l
|
provider_model_id: snowflake/arctic-embed-l
|
||||||
|
provider_id: nvidia
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields: []
|
shields: []
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
|
|
|
@ -136,30 +136,32 @@ inference_store:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: gpt-4o
|
|
||||||
provider_id: openai
|
|
||||||
provider_model_id: gpt-4o
|
provider_model_id: gpt-4o
|
||||||
|
provider_id: openai
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: claude-3-5-sonnet-latest
|
|
||||||
provider_id: anthropic
|
|
||||||
provider_model_id: claude-3-5-sonnet-latest
|
provider_model_id: claude-3-5-sonnet-latest
|
||||||
|
provider_id: anthropic
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: gemini/gemini-1.5-flash
|
|
||||||
provider_id: gemini
|
|
||||||
provider_model_id: gemini/gemini-1.5-flash
|
provider_model_id: gemini/gemini-1.5-flash
|
||||||
|
provider_id: gemini
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
|
||||||
provider_id: groq
|
|
||||||
provider_model_id: groq/llama-3.3-70b-versatile
|
provider_model_id: groq/llama-3.3-70b-versatile
|
||||||
|
provider_id: groq
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta-llama/Llama-3.1-405B-Instruct
|
|
||||||
provider_id: together
|
|
||||||
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
|
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
|
||||||
|
provider_id: together
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
model_id: meta-llama/Llama-3.1-405B-Instruct
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields:
|
shields:
|
||||||
- shield_id: meta-llama/Llama-Guard-3-8B
|
- shield_id: meta-llama/Llama-Guard-3-8B
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
|
|
|
@ -75,7 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
|
|
||||||
default_models = [
|
default_models = [
|
||||||
ModelInput(
|
ModelInput(
|
||||||
model_id="${env.INFERENCE_MODEL}",
|
provider_model_id="${env.INFERENCE_MODEL}",
|
||||||
provider_id="vllm-inference",
|
provider_id="vllm-inference",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@ -85,7 +85,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
embedding_model = ModelInput(
|
embedding_model = ModelInput(
|
||||||
model_id="all-MiniLM-L6-v2",
|
provider_model_id="all-MiniLM-L6-v2",
|
||||||
provider_id=embedding_provider.provider_id,
|
provider_id=embedding_provider.provider_id,
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={
|
metadata={
|
||||||
|
|
|
@ -88,14 +88,16 @@ inference_store:
|
||||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||||
models:
|
models:
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: ${env.INFERENCE_MODEL}
|
provider_model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: vllm-inference
|
provider_id: vllm-inference
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 384
|
embedding_dimension: 384
|
||||||
model_id: all-MiniLM-L6-v2
|
provider_model_id: all-MiniLM-L6-v2
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_type: embedding
|
model_type: embedding
|
||||||
|
use_provider_model_id_as_id: false
|
||||||
shields:
|
shields:
|
||||||
- shield_id: meta-llama/Llama-Guard-3-8B
|
- shield_id: meta-llama/Llama-Guard-3-8B
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
|
|
|
@ -114,7 +114,7 @@ def get_model_registry(
|
||||||
identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
|
identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
|
||||||
models.append(
|
models.append(
|
||||||
ModelInput(
|
ModelInput(
|
||||||
model_id=identifier,
|
model_id=identifier if identifier != entry.provider_model_id else None,
|
||||||
provider_model_id=entry.provider_model_id,
|
provider_model_id=entry.provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
model_type=entry.model_type,
|
model_type=entry.model_type,
|
||||||
|
|
|
@ -73,7 +73,7 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||||
]
|
]
|
||||||
|
|
||||||
embedding_model = ModelInput(
|
embedding_model = ModelInput(
|
||||||
model_id="all-MiniLM-L6-v2",
|
provider_model_id="all-MiniLM-L6-v2",
|
||||||
provider_id="sentence-transformers",
|
provider_id="sentence-transformers",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={
|
metadata={
|
||||||
|
|
|
@ -9,11 +9,19 @@ import os
|
||||||
import re
|
import re
|
||||||
from logging.config import dictConfig # allow-direct-logging
|
from logging.config import dictConfig # allow-direct-logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.errors import MarkupError
|
from rich.errors import MarkupError
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
from llama_stack.core.datatypes import LoggingConfig
|
|
||||||
|
class LoggingConfig(BaseModel):
|
||||||
|
category_levels: dict[str, str] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="""
|
||||||
|
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Default log level
|
# Default log level
|
||||||
DEFAULT_LOG_LEVEL = logging.INFO
|
DEFAULT_LOG_LEVEL = logging.INFO
|
||||||
|
|
|
@ -645,3 +645,88 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
await table.shutdown()
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_register_with_use_provider_model_id_as_id(cached_disk_dist_registry):
|
||||||
|
"""Test register_model with the new use_provider_model_id_as_id parameter."""
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Register model using use_provider_model_id_as_id parameter
|
||||||
|
await table.register_model(
|
||||||
|
provider_model_id="actual-provider-model", provider_id="test_provider", use_provider_model_id_as_id=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the model was registered with provider_model_id as identifier
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 1
|
||||||
|
model = models.data[0]
|
||||||
|
assert model.identifier == "actual-provider-model"
|
||||||
|
assert model.provider_resource_id == "actual-provider-model"
|
||||||
|
assert model.provider_id == "test_provider"
|
||||||
|
|
||||||
|
# Test lookup by provider_model_id works
|
||||||
|
retrieved_model = await table.get_model("actual-provider-model")
|
||||||
|
assert retrieved_model.identifier == "actual-provider-model"
|
||||||
|
assert retrieved_model.provider_resource_id == "actual-provider-model"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_register_provider_model_id_only(cached_disk_dist_registry):
|
||||||
|
"""Test register_model with only provider_model_id (new recommended usage)."""
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Register model using only provider_model_id
|
||||||
|
await table.register_model(provider_model_id="llama-3.1-8b", provider_id="test_provider", model_type=ModelType.llm)
|
||||||
|
|
||||||
|
# Verify the model was registered with namespaced identifier
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 1
|
||||||
|
model = models.data[0]
|
||||||
|
assert model.identifier == "test_provider/llama-3.1-8b"
|
||||||
|
assert model.provider_resource_id == "llama-3.1-8b"
|
||||||
|
assert model.provider_id == "test_provider"
|
||||||
|
|
||||||
|
# Test lookup works
|
||||||
|
retrieved_model = await table.get_model("test_provider/llama-3.1-8b")
|
||||||
|
assert retrieved_model.identifier == "test_provider/llama-3.1-8b"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_register_validation_errors(cached_disk_dist_registry):
|
||||||
|
"""Test register_model validation errors."""
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Test error when neither model_id nor provider_model_id is provided
|
||||||
|
with pytest.raises(ValueError, match="provider_model_id must be provided"):
|
||||||
|
await table.register_model(provider_id="test_provider")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_register_backward_compatibility_warning(cached_disk_dist_registry):
|
||||||
|
"""Test that register_model warns when model_id equals provider_model_id."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Test warning is logged when model_id == provider_model_id
|
||||||
|
with patch("llama_stack.core.routing_tables.models.logger") as mock_logger:
|
||||||
|
await table.register_model(model_id="same-model", provider_model_id="same-model", provider_id="test_provider")
|
||||||
|
|
||||||
|
# Verify warning was called
|
||||||
|
mock_logger.warning.assert_called_once()
|
||||||
|
warning_msg = mock_logger.warning.call_args[0][0]
|
||||||
|
assert "model_id` is now optional" in warning_msg
|
||||||
|
assert "provider_model_id='same-model'" in warning_msg
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue