From b67aef2fc42b4f8cbddf0bec2c846234c869e04c Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 25 Sep 2025 17:17:00 -0400 Subject: [PATCH] feat: add static embedding metadata to dynamic model listings for providers using OpenAIMixin (#3547) # What does this PR do? - remove auto-download of ollama embedding models - add embedding model metadata to dynamic listing w/ unit test - add support and tests for allowed_models - remove inference provider models.py files where dynamic listing is enabled - store embedding metadata in embedding_model_metadata field on inference providers - make model_entries optional on ModelRegistryHelper and LiteLLMOpenAIMixin - make OpenAIMixin a ModelRegistryHelper - skip base64 embedding test for remote::ollama, which always returns floats - only use OpenAI client for ollama model listing - remove unused build_model_entry function - remove unused get_huggingface_repo function ## Test Plan ci w/ new tests --- .../self_hosted_distro/nvidia.md | 19 --- llama_stack/distributions/nvidia/nvidia.py | 9 +- llama_stack/distributions/nvidia/run.yaml | 85 +---------- .../providers/remote/eval/nvidia/eval.py | 3 +- .../remote/inference/anthropic/anthropic.py | 14 +- .../remote/inference/anthropic/models.py | 40 ----- .../providers/remote/inference/azure/azure.py | 2 - .../remote/inference/azure/models.py | 28 ---- .../remote/inference/bedrock/bedrock.py | 2 +- .../remote/inference/cerebras/cerebras.py | 5 - .../remote/inference/cerebras/models.py | 28 ---- .../remote/inference/databricks/databricks.py | 43 ++---- .../remote/inference/fireworks/fireworks.py | 7 +- .../remote/inference/fireworks/models.py | 70 --------- .../remote/inference/gemini/gemini.py | 6 +- .../remote/inference/gemini/models.py | 34 ----- .../providers/remote/inference/groq/groq.py | 3 - .../providers/remote/inference/groq/models.py | 48 ------ .../inference/llama_openai_compat/llama.py | 3 - .../inference/llama_openai_compat/models.py | 25 ---- .../remote/inference/nvidia/models.py | 109 -------------- .../remote/inference/nvidia/nvidia.py | 17 ++- .../remote/inference/ollama/models.py | 106 -------------- .../remote/inference/ollama/ollama.py | 138 +++++++----------- .../remote/inference/openai/models.py | 60 -------- .../remote/inference/openai/openai.py | 7 +- .../inference/passthrough/passthrough.py | 2 +- .../remote/inference/sambanova/models.py | 28 ---- .../remote/inference/sambanova/sambanova.py | 2 - .../remote/inference/together/models.py | 103 ------------- .../remote/inference/together/together.py | 18 ++- .../remote/inference/vertexai/models.py | 20 --- .../remote/inference/vertexai/vertexai.py | 2 - .../providers/remote/inference/vllm/vllm.py | 2 +- .../remote/inference/watsonx/watsonx.py | 2 +- .../utils/inference/litellm_openai_mixin.py | 6 +- .../utils/inference/model_registry.py | 27 +--- .../providers/utils/inference/openai_mixin.py | 47 ++++-- .../models-bd032f995f2a-3255f444.json | 96 ++++++++++++ .../inference/test_litellm_openai_mixin.py | 1 - tests/unit/providers/nvidia/test_eval.py | 2 +- .../utils/inference/test_openai_mixin.py | 110 +++++++++++++- .../providers/utils/test_model_registry.py | 4 +- 43 files changed, 368 insertions(+), 1015 deletions(-) delete mode 100644 llama_stack/providers/remote/inference/anthropic/models.py delete mode 100644 llama_stack/providers/remote/inference/azure/models.py delete mode 100644 llama_stack/providers/remote/inference/cerebras/models.py delete mode 100644 llama_stack/providers/remote/inference/fireworks/models.py delete mode 100644 
llama_stack/providers/remote/inference/gemini/models.py delete mode 100644 llama_stack/providers/remote/inference/groq/models.py delete mode 100644 llama_stack/providers/remote/inference/llama_openai_compat/models.py delete mode 100644 llama_stack/providers/remote/inference/nvidia/models.py delete mode 100644 llama_stack/providers/remote/inference/ollama/models.py delete mode 100644 llama_stack/providers/remote/inference/openai/models.py delete mode 100644 llama_stack/providers/remote/inference/sambanova/models.py delete mode 100644 llama_stack/providers/remote/inference/together/models.py delete mode 100644 llama_stack/providers/remote/inference/vertexai/models.py create mode 100644 tests/integration/recordings/responses/models-bd032f995f2a-3255f444.json diff --git a/docs/docs/distributions/self_hosted_distro/nvidia.md b/docs/docs/distributions/self_hosted_distro/nvidia.md index fba411640..1e52797db 100644 --- a/docs/docs/distributions/self_hosted_distro/nvidia.md +++ b/docs/docs/distributions/self_hosted_distro/nvidia.md @@ -37,25 +37,6 @@ The following environment variables can be configured: - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`) - `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`) -### Models - -The following models are available by default: - -- `meta/llama3-8b-instruct ` -- `meta/llama3-70b-instruct ` -- `meta/llama-3.1-8b-instruct ` -- `meta/llama-3.1-70b-instruct ` -- `meta/llama-3.1-405b-instruct ` -- `meta/llama-3.2-1b-instruct ` -- `meta/llama-3.2-3b-instruct ` -- `meta/llama-3.2-11b-vision-instruct ` -- `meta/llama-3.2-90b-vision-instruct ` -- `meta/llama-3.3-70b-instruct ` -- `nvidia/vila ` -- `nvidia/llama-3.2-nv-embedqa-1b-v2 ` -- `nvidia/nv-embedqa-e5-v5 ` -- `nvidia/nv-embedqa-mistral-7b-v2 ` -- `snowflake/arctic-embed-l ` ## Prerequisites diff --git a/llama_stack/distributions/nvidia/nvidia.py b/llama_stack/distributions/nvidia/nvidia.py index 779fabf2c..b41eea130 100644 --- a/llama_stack/distributions/nvidia/nvidia.py +++ b/llama_stack/distributions/nvidia/nvidia.py @@ -7,12 +7,11 @@ from pathlib import Path from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput -from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig -from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig @@ -68,9 +67,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate: provider_id="nvidia", ) - available_models = { - "nvidia": MODEL_ENTRIES, - } default_tool_groups = [ ToolGroupInput( toolgroup_id="builtin::rag", @@ -78,7 +74,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate: ), ] - default_models, _ = get_model_registry(available_models) return DistributionTemplate( name=name, distro_type="self_hosted", @@ -86,7 +81,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate: container_image=None, template_path=Path(__file__).parent / "doc_template.md", 
providers=providers, - available_models_by_provider=available_models, run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ @@ -95,7 +89,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate: "eval": [eval_provider], "files": [files_provider], }, - default_models=default_models, default_tool_groups=default_tool_groups, ), "run-with-safety.yaml": RunConfigSettings( diff --git a/llama_stack/distributions/nvidia/run.yaml b/llama_stack/distributions/nvidia/run.yaml index 362970d2e..3f3cfc514 100644 --- a/llama_stack/distributions/nvidia/run.yaml +++ b/llama_stack/distributions/nvidia/run.yaml @@ -92,90 +92,7 @@ metadata_store: inference_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db -models: -- metadata: {} - model_id: meta/llama3-8b-instruct - provider_id: nvidia - provider_model_id: meta/llama3-8b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama3-70b-instruct - provider_id: nvidia - provider_model_id: meta/llama3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.1-8b-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.1-70b-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.1-405b-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.1-405b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-1b-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-1b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-3b-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-11b-vision-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.2-90b-vision-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta/llama-3.3-70b-instruct - provider_id: nvidia - provider_model_id: meta/llama-3.3-70b-instruct - model_type: llm -- metadata: {} - model_id: nvidia/vila - provider_id: nvidia - provider_model_id: nvidia/vila - model_type: llm -- metadata: - embedding_dimension: 2048 - context_length: 8192 - model_id: nvidia/llama-3.2-nv-embedqa-1b-v2 - provider_id: nvidia - provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2 - model_type: embedding -- metadata: - embedding_dimension: 1024 - context_length: 512 - model_id: nvidia/nv-embedqa-e5-v5 - provider_id: nvidia - provider_model_id: nvidia/nv-embedqa-e5-v5 - model_type: embedding -- metadata: - embedding_dimension: 4096 - context_length: 512 - model_id: nvidia/nv-embedqa-mistral-7b-v2 - provider_id: nvidia - provider_model_id: nvidia/nv-embedqa-mistral-7b-v2 - model_type: embedding -- metadata: - embedding_dimension: 1024 - context_length: 512 - model_id: snowflake/arctic-embed-l - provider_id: nvidia - provider_model_id: snowflake/arctic-embed-l - model_type: embedding +models: [] shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/providers/remote/eval/nvidia/eval.py b/llama_stack/providers/remote/eval/nvidia/eval.py index a474e78e3..8fc7ffdd3 100644 --- a/llama_stack/providers/remote/eval/nvidia/eval.py +++ b/llama_stack/providers/remote/eval/nvidia/eval.py @@ -14,7 +14,6 @@ from llama_stack.apis.datasets import Datasets from 
llama_stack.apis.inference import Inference from llama_stack.apis.scoring import Scoring, ScoringResult from llama_stack.providers.datatypes import BenchmarksProtocolPrivate -from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from .....apis.common.job_types import Job, JobStatus @@ -45,7 +44,7 @@ class NVIDIAEvalImpl( self.inference_api = inference_api self.agents_api = agents_api - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) + ModelRegistryHelper.__init__(self) async def initialize(self) -> None: ... diff --git a/llama_stack/providers/remote/inference/anthropic/anthropic.py b/llama_stack/providers/remote/inference/anthropic/anthropic.py index 0f247218d..cdde4a411 100644 --- a/llama_stack/providers/remote/inference/anthropic/anthropic.py +++ b/llama_stack/providers/remote/inference/anthropic/anthropic.py @@ -8,14 +8,24 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import AnthropicConfig -from .models import MODEL_ENTRIES class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + # source: https://docs.claude.com/en/docs/build-with-claude/embeddings + # TODO: add support for voyageai, which is where these models are hosted + # embedding_model_metadata = { + # "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048 + # "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048 + # "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048 + # "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048 + # "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000}, + # "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000}, + # "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000}, + # } + def __init__(self, config: AnthropicConfig) -> None: LiteLLMOpenAIMixin.__init__( self, - MODEL_ENTRIES, litellm_provider_name="anthropic", api_key_from_config=config.api_key, provider_data_api_key_field="anthropic_api_key", diff --git a/llama_stack/providers/remote/inference/anthropic/models.py b/llama_stack/providers/remote/inference/anthropic/models.py deleted file mode 100644 index 4cbe44b02..000000000 --- a/llama_stack/providers/remote/inference/anthropic/models.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.apis.models import ModelType -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, -) - -LLM_MODEL_IDS = [ - "claude-3-5-sonnet-latest", - "claude-3-7-sonnet-latest", - "claude-3-5-haiku-latest", -] - -SAFETY_MODELS_ENTRIES = [] - -MODEL_ENTRIES = ( - [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] - + [ - ProviderModelEntry( - provider_model_id="voyage-3", - model_type=ModelType.embedding, - metadata={"embedding_dimension": 1024, "context_length": 32000}, - ), - ProviderModelEntry( - provider_model_id="voyage-3-lite", - model_type=ModelType.embedding, - metadata={"embedding_dimension": 512, "context_length": 32000}, - ), - ProviderModelEntry( - provider_model_id="voyage-code-3", - model_type=ModelType.embedding, - metadata={"embedding_dimension": 1024, "context_length": 32000}, - ), - ] - + SAFETY_MODELS_ENTRIES -) diff --git a/llama_stack/providers/remote/inference/azure/azure.py b/llama_stack/providers/remote/inference/azure/azure.py index 449bbbb1c..a2c69b69c 100644 --- a/llama_stack/providers/remote/inference/azure/azure.py +++ b/llama_stack/providers/remote/inference/azure/azure.py @@ -14,14 +14,12 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import ( from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import AzureConfig -from .models import MODEL_ENTRIES class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: AzureConfig) -> None: LiteLLMOpenAIMixin.__init__( self, - MODEL_ENTRIES, litellm_provider_name="azure", api_key_from_config=config.api_key.get_secret_value(), provider_data_api_key_field="azure_api_key", diff --git a/llama_stack/providers/remote/inference/azure/models.py b/llama_stack/providers/remote/inference/azure/models.py deleted file mode 100644 index 64c87969b..000000000 --- a/llama_stack/providers/remote/inference/azure/models.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, -) - -# https://learn.microsoft.com/en-us/azure/ai-foundry/openai/concepts/models?tabs=global-standard%2Cstandard-chat-completions -LLM_MODEL_IDS = [ - "gpt-5", - "gpt-5-mini", - "gpt-5-nano", - "gpt-5-chat", - "o1", - "o1-mini", - "o3-mini", - "o4-mini", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", -] - -SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]() - -MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 106caed9b..29b935bbd 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -98,7 +98,7 @@ class BedrockInferenceAdapter( OpenAICompletionToLlamaStackMixin, ): def __init__(self, config: BedrockConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) self._config = config self._client = None diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 6947dbc87..6662f004d 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -49,7 +49,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from .config import CerebrasImplConfig -from .models import MODEL_ENTRIES class CerebrasInferenceAdapter( @@ -58,10 +57,6 @@ class CerebrasInferenceAdapter( Inference, ): def __init__(self, config: CerebrasImplConfig) -> None: - ModelRegistryHelper.__init__( - self, - model_entries=MODEL_ENTRIES, - ) self.config = config # TODO: make this use provider data, etc. like other providers diff --git a/llama_stack/providers/remote/inference/cerebras/models.py b/llama_stack/providers/remote/inference/cerebras/models.py deleted file mode 100644 index 4de2e62c9..000000000 --- a/llama_stack/providers/remote/inference/cerebras/models.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - build_hf_repo_model_entry, -) - -SAFETY_MODELS_ENTRIES = [] - -# https://inference-docs.cerebras.ai/models -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "llama3.1-8b", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "llama-3.3-70b", - CoreModelId.llama3_3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "llama-4-scout-17b-16e-instruct", - CoreModelId.llama4_scout_17b_16e_instruct.value, - ), -] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index f2dc302e0..25fd9f3b7 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -23,6 +23,8 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + Model, + ModelType, OpenAICompletion, ResponseFormat, SamplingParams, @@ -32,11 +34,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, -) from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import DatabricksImplConfig @@ -44,29 +42,16 @@ from .config import DatabricksImplConfig logger = get_logger(name=__name__, category="inference::databricks") -# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models -EMBEDDING_MODEL_ENTRIES = { - "databricks-gte-large-en": ProviderModelEntry( - provider_model_id="databricks-gte-large-en", - metadata={ - "embedding_dimension": 1024, - "context_length": 8192, - }, - ), - "databricks-bge-large-en": ProviderModelEntry( - provider_model_id="databricks-bge-large-en", - metadata={ - "embedding_dimension": 1024, - "context_length": 512, - }, - ), -} - - class DatabricksInferenceAdapter( OpenAIMixin, Inference, ): + # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models + embedding_model_metadata = { + "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192}, + "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512}, + } + def __init__(self, config: DatabricksImplConfig) -> None: self.config = config @@ -156,11 +141,11 @@ class DatabricksInferenceAdapter( if endpoint.task == "llm/v1/chat": model.model_type = ModelType.llm # this is redundant, but informative elif endpoint.task == "llm/v1/embeddings": - if endpoint.name not in EMBEDDING_MODEL_ENTRIES: + if endpoint.name not in self.embedding_model_metadata: logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.") continue model.model_type = ModelType.embedding - model.metadata = EMBEDDING_MODEL_ENTRIES[endpoint.name].metadata + model.metadata = self.embedding_model_metadata[endpoint.name] else: logger.warning(f"Unknown model type, skipping: {endpoint}") continue @@ -169,13 +154,5 @@ class DatabricksInferenceAdapter( return list(self._model_cache.values()) - async def register_model(self, model: Model) -> Model: - if not await self.check_model_availability(model.provider_resource_id): - raise ValueError(f"Model {model.provider_resource_id} is not available in Databricks workspace.") - return model - - 
async def unregister_model(self, model_id: str) -> None: - pass - async def should_refresh_models(self) -> bool: return False diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 2fcf1be2e..cf7e93974 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -54,15 +54,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from .config import FireworksImplConfig -from .models import MODEL_ENTRIES logger = get_logger(name=__name__, category="inference::fireworks") class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData): + embedding_model_metadata = { + "nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192}, + } + def __init__(self, config: FireworksImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) self.config = config + self.allowed_models = config.allowed_models async def initialize(self) -> None: pass diff --git a/llama_stack/providers/remote/inference/fireworks/models.py b/llama_stack/providers/remote/inference/fireworks/models.py deleted file mode 100644 index 30807a0d4..000000000 --- a/llama_stack/providers/remote/inference/fireworks/models.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.models import ModelType -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, - build_hf_repo_model_entry, -) - -SAFETY_MODELS_ENTRIES = [ - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-guard-3-8b", - CoreModelId.llama_guard_3_8b.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-guard-3-11b-vision", - CoreModelId.llama_guard_3_11b_vision.value, - ), -] - -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-v3p1-8b-instruct", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-v3p1-70b-instruct", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-v3p1-405b-instruct", - CoreModelId.llama3_1_405b_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-v3p2-3b-instruct", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-v3p2-11b-vision-instruct", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-v3p2-90b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama-v3p3-70b-instruct", - CoreModelId.llama3_3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama4-scout-instruct-basic", - CoreModelId.llama4_scout_17b_16e_instruct.value, - ), - build_hf_repo_model_entry( - "accounts/fireworks/models/llama4-maverick-instruct-basic", - CoreModelId.llama4_maverick_17b_128e_instruct.value, - ), - ProviderModelEntry( - provider_model_id="nomic-ai/nomic-embed-text-v1.5", - model_type=ModelType.embedding, - metadata={ - 
"embedding_dimension": 768, - "context_length": 8192, - }, - ), -] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py index 569227fdd..30ceedff0 100644 --- a/llama_stack/providers/remote/inference/gemini/gemini.py +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -8,14 +8,16 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import GeminiConfig -from .models import MODEL_ENTRIES class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + embedding_model_metadata = { + "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048}, + } + def __init__(self, config: GeminiConfig) -> None: LiteLLMOpenAIMixin.__init__( self, - MODEL_ENTRIES, litellm_provider_name="gemini", api_key_from_config=config.api_key, provider_data_api_key_field="gemini_api_key", diff --git a/llama_stack/providers/remote/inference/gemini/models.py b/llama_stack/providers/remote/inference/gemini/models.py deleted file mode 100644 index bd696b0ac..000000000 --- a/llama_stack/providers/remote/inference/gemini/models.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.models import ModelType -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, -) - -LLM_MODEL_IDS = [ - "gemini-1.5-flash", - "gemini-1.5-pro", - "gemini-2.0-flash", - "gemini-2.0-flash-lite", - "gemini-2.5-flash", - "gemini-2.5-flash-lite", - "gemini-2.5-pro", -] - -SAFETY_MODELS_ENTRIES = [] - -MODEL_ENTRIES = ( - [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] - + [ - ProviderModelEntry( - provider_model_id="text-embedding-004", - model_type=ModelType.embedding, - metadata={"embedding_dimension": 768, "context_length": 2048}, - ), - ] - + SAFETY_MODELS_ENTRIES -) diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 888953af0..e449f2005 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -9,8 +9,6 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin -from .models import MODEL_ENTRIES - class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): _config: GroqConfig @@ -18,7 +16,6 @@ class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: GroqConfig): LiteLLMOpenAIMixin.__init__( self, - model_entries=MODEL_ENTRIES, litellm_provider_name="groq", api_key_from_config=config.api_key, provider_data_api_key_field="groq_api_key", diff --git a/llama_stack/providers/remote/inference/groq/models.py b/llama_stack/providers/remote/inference/groq/models.py deleted file mode 100644 index fac66db72..000000000 --- a/llama_stack/providers/remote/inference/groq/models.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.models.llama.sku_list import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - build_hf_repo_model_entry, - build_model_entry, -) - -SAFETY_MODELS_ENTRIES = [] - -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "llama3-8b-8192", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_model_entry( - "llama-3.1-8b-instant", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "llama3-70b-8192", - CoreModelId.llama3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "llama-3.3-70b-versatile", - CoreModelId.llama3_3_70b_instruct.value, - ), - # Groq only contains a preview version for llama-3.2-3b - # Preview models aren't recommended for production use, but we include this one - # to pass the test fixture - # TODO(aidand): Replace this with a stable model once Groq supports it - build_hf_repo_model_entry( - "llama-3.2-3b-preview", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-4-scout-17b-16e-instruct", - CoreModelId.llama4_scout_17b_16e_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-4-maverick-17b-128e-instruct", - CoreModelId.llama4_maverick_17b_128e_instruct.value, - ), -] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index f2069b5e5..489b12a68 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -8,8 +8,6 @@ from llama_stack.providers.remote.inference.llama_openai_compat.config import Ll from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin -from .models import MODEL_ENTRIES - logger = get_logger(name=__name__, category="inference::llama_openai_compat") @@ -30,7 +28,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: LlamaCompatConfig): LiteLLMOpenAIMixin.__init__( self, - model_entries=MODEL_ENTRIES, litellm_provider_name="meta_llama", api_key_from_config=config.api_key, provider_data_api_key_field="llama_api_key", diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/models.py b/llama_stack/providers/remote/inference/llama_openai_compat/models.py deleted file mode 100644 index 6285e98e1..000000000 --- a/llama_stack/providers/remote/inference/llama_openai_compat/models.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - build_hf_repo_model_entry, -) - -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "Llama-3.3-70B-Instruct", - CoreModelId.llama3_3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "Llama-4-Scout-17B-16E-Instruct-FP8", - CoreModelId.llama4_scout_17b_16e_instruct.value, - ), - build_hf_repo_model_entry( - "Llama-4-Maverick-17B-128E-Instruct-FP8", - CoreModelId.llama4_maverick_17b_128e_instruct.value, - ), -] diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py deleted file mode 100644 index df07f46b6..000000000 --- a/llama_stack/providers/remote/inference/nvidia/models.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.models import ModelType -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, - build_hf_repo_model_entry, -) - -SAFETY_MODELS_ENTRIES = [] - -# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "meta/llama3-8b-instruct", - CoreModelId.llama3_8b_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama3-70b-instruct", - CoreModelId.llama3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.1-8b-instruct", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.1-70b-instruct", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.1-405b-instruct", - CoreModelId.llama3_1_405b_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.2-1b-instruct", - CoreModelId.llama3_2_1b_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.2-3b-instruct", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.2-11b-vision-instruct", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.2-90b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta/llama-3.3-70b-instruct", - CoreModelId.llama3_3_70b_instruct.value, - ), - ProviderModelEntry( - provider_model_id="nvidia/vila", - model_type=ModelType.llm, - ), - # NeMo Retriever Text Embedding models - - # - # https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html - # - # +-----------------------------------+--------+-----------+-----------+------------+ - # | Model ID | Max | Publisher | Embedding | Dynamic | - # | | Tokens | | Dimension | Embeddings | - # +-----------------------------------+--------+-----------+-----------+------------+ - # | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192 | NVIDIA | 2048 | Yes | - # | nvidia/nv-embedqa-e5-v5 | 512 | NVIDIA | 1024 | No | - # | nvidia/nv-embedqa-mistral-7b-v2 | 512 | NVIDIA | 4096 | No | - # | snowflake/arctic-embed-l | 512 | Snowflake | 1024 | No | - # +-----------------------------------+--------+-----------+-----------+------------+ - ProviderModelEntry( - provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 2048, - "context_length": 
8192, - }, - ), - ProviderModelEntry( - provider_model_id="nvidia/nv-embedqa-e5-v5", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 1024, - "context_length": 512, - }, - ), - ProviderModelEntry( - provider_model_id="nvidia/nv-embedqa-mistral-7b-v2", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 4096, - "context_length": 512, - }, - ), - ProviderModelEntry( - provider_model_id="snowflake/arctic-embed-l", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 1024, - "context_length": 512, - }, - ), - # TODO(mf): how do we handle Nemotron models? - # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct", -] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index a5475bc92..92094a0f3 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -37,9 +37,6 @@ from llama_stack.apis.inference import ( ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) from llama_stack.providers.utils.inference.openai_compat import ( convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, @@ -48,7 +45,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import content_has_media from . import NVIDIAConfig -from .models import MODEL_ENTRIES from .openai_utils import ( convert_chat_completion_request, convert_completion_request, @@ -60,7 +56,7 @@ from .utils import _is_nvidia_hosted logger = get_logger(name=__name__, category="inference::nvidia") -class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): +class NVIDIAInferenceAdapter(OpenAIMixin, Inference): """ NVIDIA Inference Adapter for Llama Stack. @@ -74,10 +70,15 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): - ModelRegistryHelper.check_model_availability() just returns False and shows a warning """ - def __init__(self, config: NVIDIAConfig) -> None: - # TODO(mf): filter by available models - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) + # source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html + embedding_model_metadata = { + "nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192}, + "nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 1024, "context_length": 512}, + "nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 4096, "context_length": 512}, + "snowflake/arctic-embed-l": {"embedding_dimension": 1024, "context_length": 512}, + } + def __init__(self, config: NVIDIAConfig) -> None: logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") if _is_nvidia_hosted(config):
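The hunk above is the heart of the change: instead of a hand-maintained models.py, the adapter declares a small `embedding_model_metadata` dict keyed by provider model id, and the shared `OpenAIMixin` folds those entries into the dynamically fetched listing. A minimal, self-contained sketch of that merge pattern — simplified stand-ins, not the actual `OpenAIMixin` code, and the class names here are hypothetical:

```python
from dataclasses import dataclass, field
from enum import Enum


class ModelType(str, Enum):
    llm = "llm"
    embedding = "embedding"


@dataclass
class Model:
    identifier: str
    provider_id: str
    model_type: ModelType
    metadata: dict = field(default_factory=dict)


class ListingMixinSketch:
    """Stand-in for the listing behavior this PR adds to OpenAIMixin."""

    # adapters override this with provider-documented values
    embedding_model_metadata: dict = {}

    def build_model(self, provider_id: str, model_id: str) -> Model:
        # ids with declared static metadata surface as embedding models;
        # everything else from the dynamic listing is treated as an LLM
        metadata = self.embedding_model_metadata.get(model_id)
        if metadata is not None:
            return Model(model_id, provider_id, ModelType.embedding, dict(metadata))
        return Model(model_id, provider_id, ModelType.llm)


class NvidiaLikeAdapter(ListingMixinSketch):
    embedding_model_metadata = {
        "nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 1024, "context_length": 512},
    }


adapter = NvidiaLikeAdapter()
assert adapter.build_model("nvidia", "nvidia/nv-embedqa-e5-v5").model_type == ModelType.embedding
assert adapter.build_model("nvidia", "meta/llama-3.1-8b-instruct").model_type == ModelType.llm
```

Any id found in the dict is surfaced as an embedding model with its metadata attached; everything else returned by the provider's listing falls through as a plain LLM.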
diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py deleted file mode 100644 index 7c0a19a1a..000000000 --- a/llama_stack/providers/remote/inference/ollama/models.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.models import ModelType -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, - build_hf_repo_model_entry, - build_model_entry, - ) - -SAFETY_MODELS_ENTRIES = [ - # The Llama Guard models don't have their full fp16 versions - # so we are going to alias their default version to the canonical SKU - build_hf_repo_model_entry( - "llama-guard3:8b", - CoreModelId.llama_guard_3_8b.value, - ), - build_hf_repo_model_entry( - "llama-guard3:1b", - CoreModelId.llama_guard_3_1b.value, - ), -] - -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "llama3.1:8b-instruct-fp16", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_model_entry( - "llama3.1:8b", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "llama3.1:70b-instruct-fp16", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_model_entry( - "llama3.1:70b", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "llama3.1:405b-instruct-fp16", - CoreModelId.llama3_1_405b_instruct.value, - ), - build_model_entry( - "llama3.1:405b", - CoreModelId.llama3_1_405b_instruct.value, - ), - build_hf_repo_model_entry( - "llama3.2:1b-instruct-fp16", - CoreModelId.llama3_2_1b_instruct.value, - ), - build_model_entry( - "llama3.2:1b", - CoreModelId.llama3_2_1b_instruct.value, - ), - build_hf_repo_model_entry( - "llama3.2:3b-instruct-fp16", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_model_entry( - "llama3.2:3b", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_hf_repo_model_entry( - "llama3.2-vision:11b-instruct-fp16", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_model_entry( - "llama3.2-vision:latest", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "llama3.2-vision:90b-instruct-fp16", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - build_model_entry( - "llama3.2-vision:90b", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "llama3.3:70b", - CoreModelId.llama3_3_70b_instruct.value, - ), - ProviderModelEntry( - provider_model_id="all-minilm:l6-v2", - aliases=["all-minilm"], - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 384, - "context_length": 512, - }, - ), - ProviderModelEntry( - provider_model_id="nomic-embed-text", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 768, - "context_length": 8192, - }, - ), -] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 67a22cbe3..81a5fb9ad 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -45,8 +45,9 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.models import Model from llama_stack.log import get_logger +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import ( HealthResponse, HealthStatus, @@ -55,6 +56,7 @@ from llama_stack.providers.datatypes import ( from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, + build_hf_repo_model_entry, ) from llama_stack.providers.utils.inference.openai_compat import ( 
OpenAICompatCompletionChoice, @@ -77,8 +79,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( request_has_media, ) -from .models import MODEL_ENTRIES - logger = get_logger(name=__name__, category="inference::ollama") @@ -90,8 +90,44 @@ class OllamaInferenceAdapter( # automatically set by the resolver when instantiating the provider __provider_id__: str + embedding_model_metadata = { + "all-minilm:l6-v2": { + "embedding_dimension": 384, + "context_length": 512, + }, + "nomic-embed-text:latest": { + "embedding_dimension": 768, + "context_length": 8192, + }, + "nomic-embed-text:v1.5": { + "embedding_dimension": 768, + "context_length": 8192, + }, + "nomic-embed-text:137m-v1.5-fp16": { + "embedding_dimension": 768, + "context_length": 8192, + }, + } + def __init__(self, config: OllamaImplConfig) -> None: - self.register_helper = ModelRegistryHelper(MODEL_ENTRIES) + # TODO: remove ModelRegistryHelper.__init__ when completion and + # chat_completion are removed. This exists to satisfy the input / + # output processing for llama models. Specifically, + # tool_calling is handled by raw template processing, + # instead of using the /api/chat endpoint w/ tools=... + ModelRegistryHelper.__init__( + self, + model_entries=[ + build_hf_repo_model_entry( + "llama3.2:3b-instruct-fp16", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_hf_repo_model_entry( + "llama-guard3:1b", + CoreModelId.llama_guard_3_1b.value, + ), + ], + ) self.config = config self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {} @@ -120,59 +156,6 @@ class OllamaInferenceAdapter( async def should_refresh_models(self) -> bool: return self.config.refresh_models - async def list_models(self) -> list[Model] | None: - provider_id = self.__provider_id__ - response = await self.ollama_client.list() - - # always add the two embedding models which can be pulled on demand - models = [ - Model( - identifier="all-minilm:l6-v2", - provider_resource_id="all-minilm:l6-v2", - provider_id=provider_id, - metadata={ - "embedding_dimension": 384, - "context_length": 512, - }, - model_type=ModelType.embedding, - ), - # add all-minilm alias - Model( - identifier="all-minilm", - provider_resource_id="all-minilm:l6-v2", - provider_id=provider_id, - metadata={ - "embedding_dimension": 384, - "context_length": 512, - }, - model_type=ModelType.embedding, - ), - Model( - identifier="nomic-embed-text", - provider_resource_id="nomic-embed-text:latest", - provider_id=provider_id, - metadata={ - "embedding_dimension": 768, - "context_length": 8192, - }, - model_type=ModelType.embedding, - ), - ] - for m in response.models: - # kill embedding models since we don't know dimensions for them - if "bert" in m.details.family: - continue - models.append( - Model( - identifier=m.model, - provider_resource_id=m.model, - provider_id=provider_id, - metadata={}, - model_type=ModelType.llm, - ) - ) - return models - async def health(self) -> HealthResponse: """ Performs a health check by verifying connectivity to the Ollama server. 
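The next hunk replaces Ollama's pull-and-verify registration with a simple lookup against the dynamically built model cache, keeping only a ":latest" fallback. That fallback, sketched in isolation — the helper name is hypothetical, not the adapter's actual method:

```python
# A minimal sketch of the ":latest" fallback used when resolving an Ollama
# model id against the set of available models; `available` stands in for
# the adapter's _model_cache keys.
def resolve_ollama_id(requested: str, available: set[str]) -> str:
    if requested in available:
        return requested
    # Ollama reports untagged pulls as "<name>:latest", so a request for
    # "llama3.2" should still resolve when only "llama3.2:latest" is listed
    if (latest := f"{requested}:latest") in available:
        return latest
    raise ValueError(f"Model {requested} not available; options: {sorted(available)}")


print(resolve_ollama_id("llama3.2", {"llama3.2:latest", "all-minilm:l6-v2"}))  # llama3.2:latest
```

A registration with an untagged id like `llama3.2` then resolves to `llama3.2:latest` when that is what Ollama reports, instead of failing outright.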
@@ -301,7 +284,7 @@ class OllamaInferenceAdapter( input_dict: dict[str, Any] = {} media_present = request_has_media(request) - llama_model = self.register_helper.get_llama_model(request.model) + llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): if media_present or not llama_model: contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages] @@ -409,37 +392,16 @@ class OllamaInferenceAdapter( return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: - try: - model = await self.register_helper.register_model(model) - except ValueError: - pass # Ignore statically unknown model, will check live listing + if await self.check_model_availability(model.provider_model_id): + return model + elif await self.check_model_availability(f"{model.provider_model_id}:latest"): + model.provider_resource_id = f"{model.provider_model_id}:latest" + logger.warning( + f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'" + ) + return model - if model.model_type == ModelType.embedding: - response = await self.ollama_client.list() - if model.provider_resource_id not in [m.model for m in response.models]: - await self.ollama_client.pull(model.provider_resource_id) - - # we use list() here instead of ps() - - # - ps() only lists running models, not available models - # - models not currently running are run by the ollama server as needed - response = await self.ollama_client.list() - available_models = [m.model for m in response.models] - - provider_resource_id = model.provider_resource_id - assert provider_resource_id is not None # mypy - if provider_resource_id not in available_models: - available_models_latest = [m.model.split(":latest")[0] for m in response.models] - if provider_resource_id in available_models_latest: - logger.warning( - f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" - ) - return model - raise UnsupportedModelError(provider_resource_id, available_models) - - # mutating this should be considered an anti-pattern - model.provider_resource_id = provider_resource_id - - return model + raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys())) async def openai_chat_completion( self, diff --git a/llama_stack/providers/remote/inference/openai/models.py b/llama_stack/providers/remote/inference/openai/models.py deleted file mode 100644 index 28d0c4b41..000000000 --- a/llama_stack/providers/remote/inference/openai/models.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from dataclasses import dataclass - -from llama_stack.apis.models import ModelType -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, -) - -LLM_MODEL_IDS = [ - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo", - "gpt-3.5-turbo-instruct", - "gpt-4", - "gpt-4-turbo", - "gpt-4o", - "gpt-4o-2024-08-06", - "gpt-4o-mini", - "gpt-4o-audio-preview", - "chatgpt-4o-latest", - "o1", - "o1-mini", - "o3-mini", - "o4-mini", -] - - -@dataclass -class EmbeddingModelInfo: - """Structured representation of embedding model information.""" - - embedding_dimension: int - context_length: int - - -EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = { - "text-embedding-3-small": EmbeddingModelInfo(1536, 8192), - "text-embedding-3-large": EmbeddingModelInfo(3072, 8192), -} -SAFETY_MODELS_ENTRIES = [] - -MODEL_ENTRIES = ( - [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] - + [ - ProviderModelEntry( - provider_model_id=model_id, - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": model_info.embedding_dimension, - "context_length": model_info.context_length, - }, - ) - for model_id, model_info in EMBEDDING_MODEL_IDS.items() - ] - + SAFETY_MODELS_ENTRIES -) diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 0f73c9321..18530f20b 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -9,7 +9,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig -from .models import MODEL_ENTRIES logger = get_logger(name=__name__, category="inference::openai") @@ -40,10 +39,14 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning """ + embedding_model_metadata = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192}, + } + def __init__(self, config: OpenAIConfig) -> None: LiteLLMOpenAIMixin.__init__( self, - MODEL_ENTRIES, litellm_provider_name="openai", api_key_from_config=config.api_key, provider_data_api_key_field="openai_api_key", diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 2f1cd40f2..a2bdf0369 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -43,7 +43,7 @@ from .config import PassthroughImplConfig class PassthroughInferenceAdapter(Inference): def __init__(self, config: PassthroughImplConfig) -> None: - ModelRegistryHelper.__init__(self, []) + ModelRegistryHelper.__init__(self) self.config = config async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/sambanova/models.py b/llama_stack/providers/remote/inference/sambanova/models.py deleted file mode 100644 index db781eb86..000000000 --- a/llama_stack/providers/remote/inference/sambanova/models.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - build_hf_repo_model_entry, -) - -SAFETY_MODELS_ENTRIES = [] - - -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "Meta-Llama-3.1-8B-Instruct", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "Meta-Llama-3.3-70B-Instruct", - CoreModelId.llama3_3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "Llama-4-Maverick-17B-128E-Instruct", - CoreModelId.llama4_maverick_17b_128e_instruct.value, - ), -] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index ee3b0f648..6121e81f7 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -9,7 +9,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import SambaNovaImplConfig -from .models import MODEL_ENTRIES class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): @@ -29,7 +28,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): self.environment_available_models = [] LiteLLMOpenAIMixin.__init__( self, - model_entries=MODEL_ENTRIES, litellm_provider_name="sambanova", api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None, provider_data_api_key_field="sambanova_api_key", diff --git a/llama_stack/providers/remote/inference/together/models.py b/llama_stack/providers/remote/inference/together/models.py deleted file mode 100644 index 2aba614cb..000000000 --- a/llama_stack/providers/remote/inference/together/models.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, - build_hf_repo_model_entry, -) - -SAFETY_MODELS_ENTRIES = [ - build_hf_repo_model_entry( - "meta-llama/Llama-Guard-3-8B", - CoreModelId.llama_guard_3_8b.value, - ), - build_hf_repo_model_entry( - "meta-llama/Llama-Guard-3-11B-Vision-Turbo", - CoreModelId.llama_guard_3_11b_vision.value, - ), -] - -# source: https://docs.together.ai/docs/serverless-models#embedding-models -EMBEDDING_MODEL_ENTRIES = { - "togethercomputer/m2-bert-80M-32k-retrieval": ProviderModelEntry( - provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval", - metadata={ - "embedding_dimension": 768, - "context_length": 32768, - }, - ), - "BAAI/bge-large-en-v1.5": ProviderModelEntry( - provider_model_id="BAAI/bge-large-en-v1.5", - metadata={ - "embedding_dimension": 1024, - "context_length": 512, - }, - ), - "BAAI/bge-base-en-v1.5": ProviderModelEntry( - provider_model_id="BAAI/bge-base-en-v1.5", - metadata={ - "embedding_dimension": 768, - "context_length": 512, - }, - ), - "Alibaba-NLP/gte-modernbert-base": ProviderModelEntry( - provider_model_id="Alibaba-NLP/gte-modernbert-base", - metadata={ - "embedding_dimension": 768, - "context_length": 8192, - }, - ), - "intfloat/multilingual-e5-large-instruct": ProviderModelEntry( - provider_model_id="intfloat/multilingual-e5-large-instruct", - metadata={ - "embedding_dimension": 1024, - "context_length": 512, - }, - ), -} -MODEL_ENTRIES = ( - [ - build_hf_repo_model_entry( - "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - CoreModelId.llama3_1_405b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Llama-3.2-3B-Instruct-Turbo", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Llama-3.3-70B-Instruct-Turbo", - CoreModelId.llama3_3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Llama-4-Scout-17B-16E-Instruct", - CoreModelId.llama4_scout_17b_16e_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", - CoreModelId.llama4_maverick_17b_128e_instruct.value, - ), - ] - + SAFETY_MODELS_ENTRIES - + list(EMBEDDING_MODEL_ENTRIES.values()) -) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index d45bd489f..653f84610 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -56,15 +56,22 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from .config import TogetherImplConfig -from .models import EMBEDDING_MODEL_ENTRIES, MODEL_ENTRIES logger = get_logger(name=__name__, category="inference::together") class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData): + embedding_model_metadata = { + "togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, 
"context_length": 32768}, + "BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512}, + "BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512}, + "Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192}, + "intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512}, + } + def __init__(self, config: TogetherImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) self.config = config + self.allowed_models = config.allowed_models self._model_cache: dict[str, Model] = {} def get_api_key(self): @@ -264,15 +271,16 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need # Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client for m in await self._get_client().models.list(): if m.type == "embedding": - if m.id not in EMBEDDING_MODEL_ENTRIES: + if m.id not in self.embedding_model_metadata: logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.") continue + metadata = self.embedding_model_metadata[m.id] self._model_cache[m.id] = Model( provider_id=self.__provider_id__, - provider_resource_id=EMBEDDING_MODEL_ENTRIES[m.id].provider_model_id, + provider_resource_id=m.id, identifier=m.id, model_type=ModelType.embedding, - metadata=EMBEDDING_MODEL_ENTRIES[m.id].metadata, + metadata=metadata, ) else: self._model_cache[m.id] = Model( diff --git a/llama_stack/providers/remote/inference/vertexai/models.py b/llama_stack/providers/remote/inference/vertexai/models.py deleted file mode 100644 index e72db533d..000000000 --- a/llama_stack/providers/remote/inference/vertexai/models.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
diff --git a/llama_stack/providers/remote/inference/vertexai/models.py b/llama_stack/providers/remote/inference/vertexai/models.py
deleted file mode 100644
index e72db533d..000000000
--- a/llama_stack/providers/remote/inference/vertexai/models.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.providers.utils.inference.model_registry import (
-    ProviderModelEntry,
-)
-
-# Vertex AI model IDs with vertex_ai/ prefix as required by litellm
-LLM_MODEL_IDS = [
-    "vertex_ai/gemini-2.0-flash",
-    "vertex_ai/gemini-2.5-flash",
-    "vertex_ai/gemini-2.5-pro",
-]
-
-SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
-
-MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES
diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py
index 8996543e7..770d21a2a 100644
--- a/llama_stack/providers/remote/inference/vertexai/vertexai.py
+++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py
@@ -16,14 +16,12 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import (
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import VertexAIConfig
-from .models import MODEL_ENTRIES
 
 
 class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     def __init__(self, config: VertexAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
             self,
-            MODEL_ENTRIES,
             litellm_provider_name="vertex_ai",
             api_key_from_config=None,  # Vertex AI uses ADC, not API keys
             provider_data_api_key_field="vertex_project",  # Use project for validation
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 15f807846..8fbb4b815 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -292,7 +292,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
             self,
-            build_hf_repo_model_entries(),
+            model_entries=build_hf_repo_model_entries(),
             litellm_provider_name="vllm",
             api_key_from_config=config.api_token,
             provider_data_api_key_field="vllm_api_token",
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index ab5ca76db..cb8b45565 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -76,7 +76,7 @@ logger = get_logger(name=__name__, category="inference::watsonx")
 
 class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: WatsonXConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
         logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
 
         self._config = config
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 9bd43e4c9..b1e38f323 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -40,7 +40,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
     b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
@@ -67,10 +67,10 @@ class LiteLLMOpenAIMixin(
     # when calling litellm.
     def __init__(
         self,
-        model_entries,
         litellm_provider_name: str,
         api_key_from_config: str | None,
         provider_data_api_key_field: str,
+        model_entries: list[ProviderModelEntry] | None = None,
         openai_compat_api_base: str | None = None,
         download_images: bool = False,
         json_schema_strict: bool = True,
@@ -86,7 +86,7 @@ class LiteLLMOpenAIMixin(
         :param download_images: Whether to download images and convert to base64 for message conversion.
         :param json_schema_strict: Whether to use strict mode for JSON schema validation.
         """
-        ModelRegistryHelper.__init__(self, model_entries)
+        ModelRegistryHelper.__init__(self, model_entries=model_entries)
         self.litellm_provider_name = litellm_provider_name
         self.api_key_from_config = api_key_from_config
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index b6b06c0b6..ff15b2d43 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
 from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
-from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference import (
     ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
@@ -37,13 +36,6 @@ class ProviderModelEntry(BaseModel):
     metadata: dict[str, Any] = Field(default_factory=dict)
 
 
-def get_huggingface_repo(model_descriptor: str) -> str | None:
-    for model in all_registered_models():
-        if model.descriptor() == model_descriptor:
-            return model.huggingface_repo
-    return None
-
-
 def build_hf_repo_model_entry(
     provider_model_id: str,
     model_descriptor: str,
@@ -63,25 +55,20 @@
     )
 
 
-def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
-    return ProviderModelEntry(
-        provider_model_id=provider_model_id,
-        aliases=[],
-        llama_model=model_descriptor,
-        model_type=ModelType.llm,
-    )
-
-
 class ModelRegistryHelper(ModelsProtocolPrivate):
     __provider_id__: str
 
-    def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
-        self.model_entries = model_entries
+    def __init__(
+        self,
+        model_entries: list[ProviderModelEntry] | None = None,
+        allowed_models: list[str] | None = None,
+    ):
         self.allowed_models = allowed_models
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
-        for entry in model_entries:
+        self.model_entries = model_entries or []
+        for entry in self.model_entries:
             for alias in entry.aliases:
                 self.alias_to_provider_id_map[alias] = entry.provider_model_id
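A quick illustration of what the relaxed `ModelRegistryHelper` signature allows (a sketch; the entry values are hypothetical): providers with static catalogs still pass entries by keyword, while providers that list models dynamically can construct the helper with no arguments.

```python
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    ProviderModelEntry,
)

# Static catalog: entries passed by keyword, optionally restricted.
static_helper = ModelRegistryHelper(
    model_entries=[ProviderModelEntry(provider_model_id="example/llm-8b")],
    allowed_models=["example/llm-8b"],
)

# Dynamic listing: model_entries now defaults to an empty list.
dynamic_helper = ModelRegistryHelper()
```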
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 2fe343f63..84211dc96 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -24,12 +24,13 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
 
 logger = get_logger(name=__name__, category="providers::utils")
 
 
-class OpenAIMixin(ABC):
+class OpenAIMixin(ModelRegistryHelper, ABC):
     """
     Mixin class that provides OpenAI-specific functionality for inference providers.
     This class handles direct OpenAI API calls using the AsyncOpenAI client.
@@ -50,10 +51,18 @@ class OpenAIMixin(ABC):
     # This is useful for providers that do not return a unique id in the response.
     overwrite_completion_id: bool = False
 
+    # Embedding model metadata for this provider
+    # Can be set by subclasses or instances to provide embedding models
+    # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
+    embedding_model_metadata: dict[str, dict[str, int]] = {}
+
     # Cache of available models keyed by model ID
     # This is set in list_models() and used in check_model_availability()
     _model_cache: dict[str, Model] = {}
 
+    # List of allowed models for this provider, if empty all models allowed
+    allowed_models: list[str] = []
+
     @abstractmethod
     def get_api_key(self) -> str:
         """
@@ -302,22 +311,36 @@ class OpenAIMixin(ABC):
 
     async def list_models(self) -> list[Model] | None:
         """
-        List available models from the provider's /v1/models endpoint.
+        List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
 
         Also, caches the models in self._model_cache for use in check_model_availability().
 
         :return: A list of Model instances representing available models.
         """
-        self._model_cache = {
-            m.id: Model(
-                # __provider_id__ is dynamically added by instantiate_provider in resolver.py
-                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                provider_resource_id=m.id,
-                identifier=m.id,
-                model_type=ModelType.llm,
-            )
-            async for m in self.client.models.list()
-        }
+        self._model_cache = {}
+
+        async for m in self.client.models.list():
+            if self.allowed_models and m.id not in self.allowed_models:
+                logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
+                continue
+            if metadata := self.embedding_model_metadata.get(m.id):
+                # This is an embedding model - augment with metadata
+                model = Model(
+                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                    provider_resource_id=m.id,
+                    identifier=m.id,
+                    model_type=ModelType.embedding,
+                    metadata=metadata,
+                )
+            else:
+                # This is an LLM
+                model = Model(
+                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                    provider_resource_id=m.id,
+                    identifier=m.id,
+                    model_type=ModelType.llm,
+                )
+            self._model_cache[m.id] = model
 
         return list(self._model_cache.values())
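To make the new `list_models()` behavior concrete, a hedged usage sketch (assuming any `OpenAIMixin`-based adapter instance; the model IDs are made up): `allowed_models` filters the dynamic listing, and any ID found in `embedding_model_metadata` comes back as an embedding model carrying that metadata.

```python
from llama_stack.apis.models import ModelType


async def show_listing(adapter) -> None:
    # `adapter` is any OpenAIMixin-based provider instance (hypothetical here).
    adapter.allowed_models = ["example-llm", "example-embed"]  # others are skipped
    adapter.embedding_model_metadata = {
        "example-embed": {"embedding_dimension": 768, "context_length": 8192},
    }

    for model in await adapter.list_models() or []:
        # example-llm   -> ModelType.llm,       metadata == {}
        # example-embed -> ModelType.embedding, metadata == {"embedding_dimension": 768, ...}
        print(model.identifier, model.model_type, model.metadata)
```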
"id": "llama3.2-vision:11b", + "created": 1756722893, + "object": "model", + "owned_by": "library" + } + }, + { + "__type__": "openai.types.model.Model", + "__data__": { + "id": "llama-guard3:1b", + "created": 1756671473, + "object": "model", + "owned_by": "library" + } + }, + { + "__type__": "openai.types.model.Model", + "__data__": { + "id": "all-minilm:l6-v2", + "created": 1756655274, + "object": "model", + "owned_by": "library" + } + }, + { + "__type__": "openai.types.model.Model", + "__data__": { + "id": "all-minilm:latest", + "created": 1747317111, + "object": "model", + "owned_by": "library" + } + }, + { + "__type__": "openai.types.model.Model", + "__data__": { + "id": "llama3.2:3b-instruct-fp16", + "created": 1744974677, + "object": "model", + "owned_by": "library" + } + }, + { + "__type__": "openai.types.model.Model", + "__data__": { + "id": "llama3.2:3b", + "created": 1743536220, + "object": "model", + "owned_by": "library" + } + } + ], + "is_streaming": false + } +} diff --git a/tests/unit/providers/inference/test_litellm_openai_mixin.py b/tests/unit/providers/inference/test_litellm_openai_mixin.py index bbc437edf..dc17e6abf 100644 --- a/tests/unit/providers/inference/test_litellm_openai_mixin.py +++ b/tests/unit/providers/inference/test_litellm_openai_mixin.py @@ -26,7 +26,6 @@ class TestProviderDataValidator(BaseModel): class TestLiteLLMAdapter(LiteLLMOpenAIMixin): def __init__(self, config: TestConfig): super().__init__( - model_entries=[], litellm_provider_name="test", api_key_from_config=config.api_key, provider_data_api_key_field="test_api_key", diff --git a/tests/unit/providers/nvidia/test_eval.py b/tests/unit/providers/nvidia/test_eval.py index 2bdcbbeba..55dfd7bee 100644 --- a/tests/unit/providers/nvidia/test_eval.py +++ b/tests/unit/providers/nvidia/test_eval.py @@ -150,7 +150,7 @@ class TestNVIDIAEvalImpl(unittest.TestCase): self._assert_request_body( { "config": f"nvidia/{MOCK_BENCHMARK_ID}", - "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"}, + "target": {"type": "model", "model": "Llama3.1-8B-Instruct"}, } ) diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 93f82da19..d62292542 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -13,7 +13,6 @@ from llama_stack.apis.models import ModelType from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin -# Test implementation of OpenAIMixin for testing purposes class OpenAIMixinImpl(OpenAIMixin): def __init__(self): self.__provider_id__ = "test-provider" @@ -25,12 +24,35 @@ class OpenAIMixinImpl(OpenAIMixin): raise NotImplementedError("This method should be mocked in tests") +class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin): + """Test implementation with embedding model metadata""" + + embedding_model_metadata = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, + } + + __provider_id__ = "test-provider" + + def get_api_key(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + def get_base_url(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + @pytest.fixture def mixin(): """Create a test instance of OpenAIMixin""" return OpenAIMixinImpl() +@pytest.fixture +def mixin_with_embeddings(): + """Create a test instance of OpenAIMixin with 
embedding model metadata""" + return OpenAIMixinWithEmbeddingsImpl() + + @pytest.fixture def mock_models(): """Create multiple mock OpenAI model objects""" @@ -181,3 +203,89 @@ class TestOpenAIMixinCacheBehavior: assert "some-mock-model-id" in mixin._model_cache assert "another-mock-model-id" in mixin._model_cache assert "final-mock-model-id" in mixin._model_cache + + +class TestOpenAIMixinEmbeddingModelMetadata: + """Test cases for embedding_model_metadata attribute functionality""" + + async def test_embedding_model_identified_and_augmented(self, mixin_with_embeddings, mock_client_context): + """Test that models in embedding_model_metadata are correctly identified as embeddings with metadata""" + # Create mock models: 1 embedding model and 1 LLM, while there are 2 known embedding models + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_embedding_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_embeddings, mock_client): + result = await mixin_with_embeddings.list_models() + + assert result is not None + assert len(result) == 2 + + # Find the models in the result + embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small") + llm_model = next(m for m in result if m.identifier == "gpt-4") + + # Check embedding model + assert embedding_model.model_type == ModelType.embedding + assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192} + assert embedding_model.provider_id == "test-provider" + assert embedding_model.provider_resource_id == "text-embedding-3-small" + + # Check LLM model + assert llm_model.model_type == ModelType.llm + assert llm_model.metadata == {} # No metadata for LLMs + assert llm_model.provider_id == "test-provider" + assert llm_model.provider_resource_id == "gpt-4" + + +class TestOpenAIMixinAllowedModels: + """Test cases for allowed_models filtering functionality""" + + async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context): + """Test that list_models filters models based on allowed_models set""" + mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"} + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 2 + + model_ids = [model.identifier for model in result] + assert "some-mock-model-id" in model_ids + assert "another-mock-model-id" in model_ids + assert "final-mock-model-id" not in model_ids + + async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context): + """Test that empty allowed_models set allows all models""" + assert len(mixin.allowed_models) == 0 + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 3 # All models should be included + + model_ids = [model.identifier for model in result] + assert "some-mock-model-id" in model_ids + assert "another-mock-model-id" in model_ids + assert "final-mock-model-id" in model_ids + + async def test_check_model_availability_with_allowed_models( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that check_model_availability respects allowed_models""" + mixin.allowed_models 
= {"final-mock-model-id"} + + with mock_client_context(mixin, mock_client_with_models): + assert await mixin.check_model_availability("final-mock-model-id") + assert not await mixin.check_model_availability("some-mock-model-id") + assert not await mixin.check_model_availability("another-mock-model-id") diff --git a/tests/unit/providers/utils/test_model_registry.py b/tests/unit/providers/utils/test_model_registry.py index db1630000..04e75aa82 100644 --- a/tests/unit/providers/utils/test_model_registry.py +++ b/tests/unit/providers/utils/test_model_registry.py @@ -84,14 +84,14 @@ def unknown_model() -> Model: @pytest.fixture def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper: - return ModelRegistryHelper([known_provider_model, known_provider_model2]) + return ModelRegistryHelper(model_entries=[known_provider_model, known_provider_model2]) class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper): """Test helper that simulates a provider with dynamically available models.""" def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]): - super().__init__(model_entries) + super().__init__(model_entries=model_entries) self._available_models = available_models async def check_model_availability(self, model: str) -> bool: