feat: add static embedding metadata to dynamic model listings for providers using OpenAIMixin

- remove auto-download of ollama embedding models
- add embedding model metadata to dynamic listing w/ unit test
- add support and tests for allowed_models
- removed inference provider models.py files where dynamic listing is enabled
- store embedding metadata in embedding_model_metadata field on inference providers
- make model_entries optional on ModelRegistryHelper and LiteLLMOpenAIMixin
- make OpenAIMixin a ModelRegistryHelper
- skip base64 embedding test for remote::ollama, always returns floats
- only use OpenAI client for ollama model listing
- remove unused build_model_entry function
- remove unused get_huggingface_repo function
This commit is contained in:
Matthew Farrellee 2025-09-25 04:56:54 -04:00
parent a50b63906c
commit 466ef6f490
43 changed files with 370 additions and 1016 deletions

View file

@ -40,7 +40,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
b64_encode_openai_embeddings_response,
convert_message_to_openai_dict_new,
@ -67,10 +67,10 @@ class LiteLLMOpenAIMixin(
# when calling litellm.
def __init__(
self,
model_entries,
litellm_provider_name: str,
api_key_from_config: str | None,
provider_data_api_key_field: str,
model_entries: list[ProviderModelEntry] | None = None,
openai_compat_api_base: str | None = None,
download_images: bool = False,
json_schema_strict: bool = True,
@ -86,7 +86,7 @@ class LiteLLMOpenAIMixin(
:param download_images: Whether to download images and convert to base64 for message conversion.
:param json_schema_strict: Whether to use strict mode for JSON schema validation.
"""
ModelRegistryHelper.__init__(self, model_entries)
ModelRegistryHelper.__init__(self, model_entries=model_entries)
self.litellm_provider_name = litellm_provider_name
self.api_key_from_config = api_key_from_config

View file

@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
@ -37,13 +36,6 @@ class ProviderModelEntry(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict)
def get_huggingface_repo(model_descriptor: str) -> str | None:
for model in all_registered_models():
if model.descriptor() == model_descriptor:
return model.huggingface_repo
return None
def build_hf_repo_model_entry(
provider_model_id: str,
model_descriptor: str,
@ -63,25 +55,20 @@ def build_hf_repo_model_entry(
)
def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=[],
llama_model=model_descriptor,
model_type=ModelType.llm,
)
class ModelRegistryHelper(ModelsProtocolPrivate):
__provider_id__: str
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
self.model_entries = model_entries
def __init__(
self,
model_entries: list[ProviderModelEntry] | None = None,
allowed_models: list[str] | None = None,
):
self.allowed_models = allowed_models
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for entry in model_entries:
self.model_entries = model_entries or []
for entry in self.model_entries:
for alias in entry.aliases:
self.alias_to_provider_id_map[alias] = entry.provider_model_id

View file

@ -24,12 +24,13 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ABC):
class OpenAIMixin(ModelRegistryHelper, ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
@ -50,10 +51,18 @@ class OpenAIMixin(ABC):
# This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False
# Embedding model metadata for this provider
# Can be set by subclasses or instances to provide embedding models
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
embedding_model_metadata: dict[str, dict[str, int]] = {}
# Cache of available models keyed by model ID
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
@abstractmethod
def get_api_key(self) -> str:
"""
@ -302,22 +311,36 @@ class OpenAIMixin(ABC):
async def list_models(self) -> list[Model] | None:
"""
List available models from the provider's /v1/models endpoint.
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
Also, caches the models in self._model_cache for use in check_model_availability().
:return: A list of Model instances representing available models.
"""
self._model_cache = {
m.id: Model(
# __provider_id__ is dynamically added by instantiate_provider in resolver.py
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
async for m in self.client.models.list()
}
self._model_cache = {}
async for m in self.client.models.list():
if self.allowed_models and m.id not in self.allowed_models:
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
continue
if metadata := self.embedding_model_metadata.get(m.id):
# This is an embedding model - augment with metadata
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=metadata,
)
else:
# This is an LLM
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
self._model_cache[m.id] = model
return list(self._model_cache.values())