mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
chore: OpenAIMixin implements ModelsProtocolPrivate (#3662)
# What does this PR do? add ModelsProtocolPrivate methods to OpenAIMixin this will allow providers using OpenAIMixin to use a common interface ## Test Plan ci w/ new tests
This commit is contained in:
parent
14a94e9894
commit
0a41c4ead0
8 changed files with 243 additions and 11 deletions
|
@ -25,9 +25,6 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
|
@ -44,7 +41,6 @@ from .config import CerebrasImplConfig
|
|||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
|
|
|
@ -44,7 +44,7 @@ from .config import FireworksImplConfig
|
|||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class FireworksInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||
|
|
|
@ -29,7 +29,6 @@ from llama_stack.apis.models import Model
|
|||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
|
@ -65,7 +64,6 @@ def build_hf_repo_model_entries():
|
|||
class _HfAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
|
|
@ -47,7 +47,7 @@ from .config import TogetherImplConfig
|
|||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||
|
|
|
@ -26,14 +26,14 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -73,6 +73,9 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -356,6 +359,24 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
usage=usage,
|
||||
)
|
||||
|
||||
###
|
||||
# ModelsProtocolPrivate implementation - provide model management functionality
|
||||
#
|
||||
# async def register_model(self, model: Model) -> Model: ...
|
||||
# async def unregister_model(self, model_id: str) -> None: ...
|
||||
#
|
||||
# async def list_models(self) -> list[Model] | None: ...
|
||||
# async def should_refresh_models(self) -> bool: ...
|
||||
##
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not await self.check_model_availability(model.provider_model_id):
|
||||
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
return None
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||
|
@ -400,5 +421,7 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
|
||||
return model in self._model_cache
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue