From e3396513e983b4c73ea44bb99ff39997f2cb702a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 22 Jul 2025 16:16:08 -0700 Subject: [PATCH] add configuration to control which models are exposed --- docs/source/providers/inference/remote_fireworks.md | 1 + docs/source/providers/inference/remote_together.md | 1 + .../providers/remote/inference/fireworks/config.py | 5 +++-- .../remote/inference/fireworks/fireworks.py | 2 +- .../providers/remote/inference/together/config.py | 5 +++-- .../providers/remote/inference/together/together.py | 2 +- .../providers/utils/inference/model_registry.py | 12 +++++++++++- 7 files changed, 21 insertions(+), 7 deletions(-) diff --git a/docs/source/providers/inference/remote_fireworks.md b/docs/source/providers/inference/remote_fireworks.md index 351586c34..862860c29 100644 --- a/docs/source/providers/inference/remote_fireworks.md +++ b/docs/source/providers/inference/remote_fireworks.md @@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| +| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key | diff --git a/docs/source/providers/inference/remote_together.md b/docs/source/providers/inference/remote_together.md index f33ff42f2..d1fe3e82b 100644 --- a/docs/source/providers/inference/remote_together.md +++ b/docs/source/providers/inference/remote_together.md @@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| +| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key | diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 072d558f4..b23f2d31b 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -6,13 +6,14 @@ from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import Field, SecretStr +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @json_schema_type -class FireworksImplConfig(BaseModel): +class FireworksImplConfig(RemoteInferenceProviderConfig): url: str = Field( default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 1c82ff3a8..c76aa39f3 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: FireworksImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) self.config = config async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index f166e4277..211be7efe 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -6,13 +6,14 @@ from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import Field, SecretStr +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @json_schema_type -class TogetherImplConfig(BaseModel): +class TogetherImplConfig(RemoteInferenceProviderConfig): url: str = Field( default="https://api.together.xyz/v1", description="The URL for the Together AI server", diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e1eb934c5..46094c146 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference") class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: TogetherImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) self.config = config async def initialize(self) -> None: diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 84265a85a..bceeaf198 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import ( logger = get_logger(name=__name__, category="core") +class RemoteInferenceProviderConfig(BaseModel): + allowed_models: list[str] | None = Field( + default=None, + description="List of models that should be registered with the model registry. If None, all models are allowed.", + ) + + # TODO: this class is more confusing than useful right now. We need to make it # more closer to the Model class. class ProviderModelEntry(BaseModel): @@ -67,7 +74,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider class ModelRegistryHelper(ModelsProtocolPrivate): __provider_id__: str - def __init__(self, model_entries: list[ProviderModelEntry]): + def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None): + self.allowed_models = allowed_models self.alias_to_provider_id_map = {} self.provider_id_to_llama_model_map = {} for entry in model_entries: @@ -86,6 +94,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate): for entry in self.model_entries: ids = [entry.provider_model_id] + entry.aliases for id in ids: + if self.allowed_models and id not in self.allowed_models: + continue models.append( Model( model_id=id,