Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-12-24 01:03:55 +00:00
feat: created dynamic model registration for openai and llama openai compat remote inference providers
fix: removed implementation of register_model() from LiteLLMOpenAIMixin, added log message to llama in query_available_models(), added llama-api-client dependency to pyproject.toml
This commit is contained in:
parent f85189022c
commit fa5935bd80
5 changed files with 49 additions and 14 deletions
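At a glance, the commit replaces static model validation (the hard-coded register_model() lookup in LiteLLMOpenAIMixin) with per-provider query_available_models() methods that ask the remote API which models it actually serves. The sketch below is not the llama-stack code, only a minimal illustration of that shift; the class name, the helper, and the sample model ids are invented for the example.

# Illustrative only -- hypothetical names, not the llama-stack implementation.
import asyncio


class DynamicModelProvider:
    """Discovers served models from the remote API instead of a static map."""

    def __init__(self, remote_model_ids: list[str]):
        # Stand-in for an AsyncLlamaAPIClient / AsyncOpenAI client.
        self._remote_model_ids = remote_model_ids

    async def query_available_models(self) -> list[str]:
        # A real adapter would do: [m.id for m in (await client.models.list())]
        return self._remote_model_ids

    async def register_model(self, model_id: str) -> str:
        available = await self.query_available_models()
        if model_id not in available:
            raise ValueError(f"unsupported model {model_id!r}; remote serves {available}")
        return model_id


async def main() -> None:
    provider = DynamicModelProvider(["llama-4-scout", "llama-3.3-70b"])
    print(await provider.register_model("llama-4-scout"))  # accepted
    try:
        await provider.register_model("not-a-model")
    except ValueError as err:
        print(err)  # rejected with the dynamically discovered list


if __name__ == "__main__":
    asyncio.run(main())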
@@ -3,16 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging

from llama_stack.providers.remote.inference.llama_openai_compat.config import (
    LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
    LiteLLMOpenAIMixin,
)
from llama_api_client import AsyncLlamaAPIClient

from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .models import MODEL_ENTRIES

logger = logging.getLogger(__name__)


class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: LlamaCompatConfig
@@ -26,6 +27,17 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config
        self._llama_api_client = AsyncLlamaAPIClient(api_key=config.api_key)

    async def query_available_models(self) -> list[str]:
        """Query available models from the Llama API."""
        try:
            available_models = await self._llama_api_client.models.list()
            logger.info(f"Available models from Llama API: {available_models}")
            return [model.id for model in available_models]
        except Exception as e:
            logger.warning(f"Failed to query available models from Llama API: {e}")
            return []

    async def initialize(self):
        await super().initialize()
@@ -60,6 +60,17 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
        # litellm specific model names, an abstraction leak.
        self.is_openai_compat = True

    async def query_available_models(self) -> list[str]:
        """Query available models from the OpenAI API"""
        try:
            openai_client = self._get_openai_client()
            available_models = await openai_client.models.list()
            logger.info(f"Available models from OpenAI: {available_models.data}")
            return [model.id for model in available_models.data]
        except Exception as e:
            logger.warning(f"Failed to query available models from OpenAI: {e}")
            return []

    async def initialize(self) -> None:
        await super().initialize()
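For context, here is a standalone usage sketch of the same models.list() call made directly with the official openai Python SDK. It is not part of the diff; it assumes the openai package is installed and OPENAI_API_KEY is set in the environment.

# Usage sketch only -- not part of this commit.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    page = await client.models.list()
    # Mirrors OpenAIInferenceAdapter.query_available_models(): collect the ids.
    print([model.id for model in page.data])


if __name__ == "__main__":
    asyncio.run(main())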
@@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
@@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
    async def shutdown(self):
        pass

    async def register_model(self, model: Model) -> Model:
        model_id = self.get_provider_model_id(model.provider_resource_id)
        if model_id is None:
            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
        return model

    def get_litellm_model_name(self, model_id: str) -> str:
        # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
        # model_id.startswith("openai/") is for backwards compatibility.
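The trailing comments describe the openai/ prefix convention that get_litellm_model_name() has to preserve. Its body falls outside the hunk, so the following is only a hedged sketch of what such prefix normalization typically looks like, written as a hypothetical free function rather than the method itself.

# Hypothetical sketch; the real get_litellm_model_name body is not shown in the hunk.
def get_litellm_model_name(model_id: str, is_openai_compat: bool = False) -> str:
    # litellm expects provider-prefixed names (e.g. "openai/gpt-4o") on its
    # OpenAI-compatible path; keep an existing prefix for backwards compatibility.
    if is_openai_compat and not model_id.startswith("openai/"):
        return "openai/" + model_id
    return model_id


assert get_litellm_model_name("gpt-4o", is_openai_compat=True) == "openai/gpt-4o"
assert get_litellm_model_name("openai/gpt-4o", is_openai_compat=True) == "openai/gpt-4o"
assert get_litellm_model_name("gpt-4o") == "gpt-4o"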