Merge branch 'main' into add-localize-url-feature-to-openaimixin

This commit is contained in:
Matthew Farrellee 2025-09-26 16:31:59 -04:00
commit 17125fd2cf
421 changed files with 70880 additions and 5915 deletions

View file

@ -25,13 +25,14 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ABC):
class OpenAIMixin(ModelRegistryHelper, ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
@ -56,10 +57,18 @@ class OpenAIMixin(ABC):
# for providers that require base64 encoded images instead of URLs.
download_images: bool = False
# Embedding model metadata for this provider
# Can be set by subclasses or instances to provide embedding models
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
embedding_model_metadata: dict[str, dict[str, int]] = {}
# Cache of available models keyed by model ID
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
@abstractmethod
def get_api_key(self) -> str:
"""
@ -320,28 +329,42 @@ class OpenAIMixin(ABC):
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
model=model,
usage=usage,
)
async def list_models(self) -> list[Model] | None:
"""
List available models from the provider's /v1/models endpoint.
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
Also, caches the models in self._model_cache for use in check_model_availability().
:return: A list of Model instances representing available models.
"""
self._model_cache = {
m.id: Model(
# __provider_id__ is dynamically added by instantiate_provider in resolver.py
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
async for m in self.client.models.list()
}
self._model_cache = {}
async for m in self.client.models.list():
if self.allowed_models and m.id not in self.allowed_models:
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
continue
if metadata := self.embedding_model_metadata.get(m.id):
# This is an embedding model - augment with metadata
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=metadata,
)
else:
# This is an LLM
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
self._model_cache[m.id] = model
return list(self._model_cache.values())