mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-09 13:14:39 +00:00
Merge branch 'main' into add-localize-url-feature-to-openaimixin
This commit is contained in:
commit
17125fd2cf
421 changed files with 70880 additions and 5915 deletions
|
@ -25,13 +25,14 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ABC):
|
||||
class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -56,10 +57,18 @@ class OpenAIMixin(ABC):
|
|||
# for providers that require base64 encoded images instead of URLs.
|
||||
download_images: bool = False
|
||||
|
||||
# Embedding model metadata for this provider
|
||||
# Can be set by subclasses or instances to provide embedding models
|
||||
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {}
|
||||
|
||||
# Cache of available models keyed by model ID
|
||||
# This is set in list_models() and used in check_model_availability()
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -320,28 +329,42 @@ class OpenAIMixin(ABC):
|
|||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available models from the provider's /v1/models endpoint.
|
||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||
|
||||
Also, caches the models in self._model_cache for use in check_model_availability().
|
||||
|
||||
:return: A list of Model instances representing available models.
|
||||
"""
|
||||
self._model_cache = {
|
||||
m.id: Model(
|
||||
# __provider_id__ is dynamically added by instantiate_provider in resolver.py
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
async for m in self.client.models.list()
|
||||
}
|
||||
self._model_cache = {}
|
||||
|
||||
async for m in self.client.models.list():
|
||||
if self.allowed_models and m.id not in self.allowed_models:
|
||||
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
|
||||
continue
|
||||
if metadata := self.embedding_model_metadata.get(m.id):
|
||||
# This is an embedding model - augment with metadata
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
# This is an LLM
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
self._model_cache[m.id] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue