mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
add configuration to control which models are exposed
This commit is contained in:
parent
2e5ffab4e3
commit
e3396513e9
7 changed files with 21 additions and 7 deletions
|
@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||||
| `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
|
| `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
|
||||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||||
| `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
|
| `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
|
||||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
|
||||||
|
|
||||||
|
|
|
@ -6,13 +6,14 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import Field, SecretStr
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class FireworksImplConfig(BaseModel):
|
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||||
url: str = Field(
|
url: str = Field(
|
||||||
default="https://api.fireworks.ai/inference/v1",
|
default="https://api.fireworks.ai/inference/v1",
|
||||||
description="The URL for the Fireworks server",
|
description="The URL for the Fireworks server",
|
||||||
|
|
|
@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||||
def __init__(self, config: FireworksImplConfig) -> None:
|
def __init__(self, config: FireworksImplConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
|
|
@ -6,13 +6,14 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import Field, SecretStr
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class TogetherImplConfig(BaseModel):
|
class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||||
url: str = Field(
|
url: str = Field(
|
||||||
default="https://api.together.xyz/v1",
|
default="https://api.together.xyz/v1",
|
||||||
description="The URL for the Together AI server",
|
description="The URL for the Together AI server",
|
||||||
|
|
|
@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||||
def __init__(self, config: TogetherImplConfig) -> None:
|
def __init__(self, config: TogetherImplConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
|
|
@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteInferenceProviderConfig(BaseModel):
|
||||||
|
allowed_models: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: this class is more confusing than useful right now. We need to make it
|
# TODO: this class is more confusing than useful right now. We need to make it
|
||||||
# more closer to the Model class.
|
# more closer to the Model class.
|
||||||
class ProviderModelEntry(BaseModel):
|
class ProviderModelEntry(BaseModel):
|
||||||
|
@ -67,7 +74,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
|
||||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
__provider_id__: str
|
__provider_id__: str
|
||||||
|
|
||||||
def __init__(self, model_entries: list[ProviderModelEntry]):
|
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||||
|
self.allowed_models = allowed_models
|
||||||
self.alias_to_provider_id_map = {}
|
self.alias_to_provider_id_map = {}
|
||||||
self.provider_id_to_llama_model_map = {}
|
self.provider_id_to_llama_model_map = {}
|
||||||
for entry in model_entries:
|
for entry in model_entries:
|
||||||
|
@ -86,6 +94,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
for entry in self.model_entries:
|
for entry in self.model_entries:
|
||||||
ids = [entry.provider_model_id] + entry.aliases
|
ids = [entry.provider_model_id] + entry.aliases
|
||||||
for id in ids:
|
for id in ids:
|
||||||
|
if self.allowed_models and id not in self.allowed_models:
|
||||||
|
continue
|
||||||
models.append(
|
models.append(
|
||||||
Model(
|
Model(
|
||||||
model_id=id,
|
model_id=id,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue