feat: add static embedding metadata to dynamic model listings for providers using OpenAIMixin (#3547)

# What does this PR do?

- remove auto-download of ollama embedding models
- add embedding model metadata to dynamic listing w/ unit test
- add support and tests for allowed_models
- remove inference provider models.py files where dynamic listing is
enabled
- store embedding metadata in the embedding_model_metadata field on
inference providers (see the sketch after this list)
- make model_entries optional on ModelRegistryHelper and
LiteLLMOpenAIMixin
- make OpenAIMixin a ModelRegistryHelper
- skip base64 embedding test for remote::ollama, which always returns floats
- only use OpenAI client for ollama model listing
- remove unused build_model_entry function
- remove unused get_huggingface_repo function
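
To make the new fields concrete, here is a minimal sketch of a provider adapter declaring them. The adapter name and model IDs are hypothetical, and the import path is an assumption; this is not code from this PR.

```python
# Hypothetical adapter sketch (names and import path are assumed, not from this PR).
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleInferenceAdapter(OpenAIMixin):
    # Models listed here are surfaced as embedding models with this metadata
    # attached; everything else returned by /v1/models is listed as an LLM.
    embedding_model_metadata = {
        "example-embed-v1": {"embedding_dimension": 1536, "context_length": 8192},
    }

    # If non-empty, list_models() only surfaces these model IDs.
    allowed_models = ["example-embed-v1", "example-llm-v1"]

    def get_api_key(self) -> str:
        return "sk-example"  # a real adapter would read this from its config
```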


## Test Plan

ci w/ new tests
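
As a self-contained illustration of what such a test exercises (not the actual test added in this PR), the toy script below mirrors the new list_models() logic from the diff: filtering by allowed_models and tagging embedding models with their static metadata. It avoids importing llama_stack; plain dicts stand in for Model objects and a stub async generator stands in for the OpenAI client's models.list().

```python
# Toy mirror of the new list_models() behavior; not the PR's real unit test.
import asyncio
from types import SimpleNamespace

embedding_model_metadata = {"embed-1": {"embedding_dimension": 384, "context_length": 512}}
allowed_models = ["embed-1", "llm-1"]


async def provider_models():
    # Stand-in for self.client.models.list()
    for model_id in ["embed-1", "llm-1", "not-allowed"]:
        yield SimpleNamespace(id=model_id)


async def list_models():
    cache = {}
    async for m in provider_models():
        if allowed_models and m.id not in allowed_models:
            continue  # the real code logs and skips models outside the allow list
        if metadata := embedding_model_metadata.get(m.id):
            cache[m.id] = {"identifier": m.id, "model_type": "embedding", "metadata": metadata}
        else:
            cache[m.id] = {"identifier": m.id, "model_type": "llm", "metadata": {}}
    return list(cache.values())


models = asyncio.run(list_models())
assert [m["identifier"] for m in models] == ["embed-1", "llm-1"]
assert models[0]["model_type"] == "embedding"
assert models[0]["metadata"] == {"embedding_dimension": 384, "context_length": 512}
```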
Matthew Farrellee · 2025-09-25 17:17:00 -04:00 · commit b67aef2fc4 · parent a50b63906c
43 changed files with 368 additions and 1015 deletions


```diff
@@ -24,12 +24,13 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
 
 logger = get_logger(name=__name__, category="providers::utils")
 
 
-class OpenAIMixin(ABC):
+class OpenAIMixin(ModelRegistryHelper, ABC):
     """
     Mixin class that provides OpenAI-specific functionality for inference providers.
     This class handles direct OpenAI API calls using the AsyncOpenAI client.
@@ -50,10 +51,18 @@ class OpenAIMixin(ABC):
     # This is useful for providers that do not return a unique id in the response.
     overwrite_completion_id: bool = False
 
+    # Embedding model metadata for this provider
+    # Can be set by subclasses or instances to provide embedding models
+    # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
+    embedding_model_metadata: dict[str, dict[str, int]] = {}
+
     # Cache of available models keyed by model ID
     # This is set in list_models() and used in check_model_availability()
     _model_cache: dict[str, Model] = {}
 
+    # List of allowed models for this provider, if empty all models allowed
+    allowed_models: list[str] = []
+
     @abstractmethod
     def get_api_key(self) -> str:
         """
@@ -302,22 +311,36 @@
     async def list_models(self) -> list[Model] | None:
         """
-        List available models from the provider's /v1/models endpoint.
+        List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
 
         Also, caches the models in self._model_cache for use in check_model_availability().
 
         :return: A list of Model instances representing available models.
         """
-        self._model_cache = {
-            m.id: Model(
-                # __provider_id__ is dynamically added by instantiate_provider in resolver.py
-                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                provider_resource_id=m.id,
-                identifier=m.id,
-                model_type=ModelType.llm,
-            )
-            async for m in self.client.models.list()
-        }
+        self._model_cache = {}
+        async for m in self.client.models.list():
+            if self.allowed_models and m.id not in self.allowed_models:
+                logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
+                continue
+            if metadata := self.embedding_model_metadata.get(m.id):
+                # This is an embedding model - augment with metadata
+                model = Model(
+                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                    provider_resource_id=m.id,
+                    identifier=m.id,
+                    model_type=ModelType.embedding,
+                    metadata=metadata,
+                )
+            else:
+                # This is an LLM
+                model = Model(
+                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                    provider_resource_id=m.id,
+                    identifier=m.id,
+                    model_type=ModelType.llm,
+                )
+            self._model_cache[m.id] = model
 
         return list(self._model_cache.values())
```