Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-23 09:39:46 +00:00
feat: created dynamic model registration for openai and llama openai compat remote inference providers
fix: removed implementation of register_model() from LiteLLMOpenAIMixin, added log message to llama in query_available_models(), added llama-api-client dependency to pyproject.toml
Parent: f85189022c
Commit: fa5935bd80
5 changed files with 49 additions and 14 deletions
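The headline change is a new query_available_models() hook on both remote adapters. Below is a hypothetical sketch, not part of this commit, of how a registry-side caller could feed that hook into dynamic registration; the register_dynamic_models helper and the Model field names are illustrative assumptions rather than APIs confirmed by the diff.

# Hypothetical sketch, not from the commit: consuming query_available_models()
# for dynamic registration. Helper name and Model fields are assumptions.
from llama_stack.apis.models import Model


async def register_dynamic_models(adapter, provider_id: str) -> list[Model]:
    # The adapters return [] when the remote listing fails (see their except branch).
    model_ids = await adapter.query_available_models()
    return [
        Model(
            identifier=model_id,
            provider_id=provider_id,
            provider_resource_id=model_id,
        )
        for model_id in model_ids
    ]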
@@ -3,16 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import logging
+
+from llama_api_client import AsyncLlamaAPIClient
-from llama_stack.providers.remote.inference.llama_openai_compat.config import (
-    LlamaCompatConfig,
-)
-from llama_stack.providers.utils.inference.litellm_openai_mixin import (
-    LiteLLMOpenAIMixin,
-)
+from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 from .models import MODEL_ENTRIES
 
+logger = logging.getLogger(__name__)
+
 
 class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
     _config: LlamaCompatConfig
@@ -26,6 +27,17 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=config.openai_compat_api_base,
         )
         self.config = config
+        self._llama_api_client = AsyncLlamaAPIClient(api_key=config.api_key)
+
+    async def query_available_models(self) -> list[str]:
+        """Query available models from the Llama API."""
+        try:
+            available_models = await self._llama_api_client.models.list()
+            logger.info(f"Available models from Llama API: {available_models}")
+            return [model.id for model in available_models]
+        except Exception as e:
+            logger.warning(f"Failed to query available models from Llama API: {e}")
+            return []
 
     async def initialize(self):
         await super().initialize()
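For context, a minimal standalone sketch of the llama-api-client call the adapter now wraps, assuming only what the diff itself uses (AsyncLlamaAPIClient(api_key=...) and client.models.list()); the API key is a placeholder.

# Minimal sketch mirroring the adapter code above; the direct iteration over
# models.list() follows the diff, and the key is a placeholder.
import asyncio

from llama_api_client import AsyncLlamaAPIClient


async def list_llama_models(api_key: str) -> list[str]:
    client = AsyncLlamaAPIClient(api_key=api_key)
    models = await client.models.list()
    return [model.id for model in models]


print(asyncio.run(list_llama_models("<LLAMA_API_KEY>")))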
@@ -60,6 +60,17 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
 
+    async def query_available_models(self) -> list[str]:
+        """Query available models from the OpenAI API"""
+        try:
+            openai_client = self._get_openai_client()
+            available_models = await openai_client.models.list()
+            logger.info(f"Available models from OpenAI: {available_models.data}")
+            return [model.id for model in available_models.data]
+        except Exception as e:
+            logger.warning(f"Failed to query available models from OpenAI: {e}")
+            return []
+
     async def initialize(self) -> None:
         await super().initialize()
 
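The OpenAI variant goes through self._get_openai_client() and reads the .data attribute of the page returned by models.list(). A standalone sketch of the underlying openai call, with the API key as a placeholder:

# Standalone sketch of the openai call wrapped above; models.list() returns a
# page object whose .data holds the model entries, matching the adapter's usage.
import asyncio

from openai import AsyncOpenAI


async def list_openai_models(api_key: str) -> list[str]:
    client = AsyncOpenAI(api_key=api_key)
    page = await client.models.list()
    return [model.id for model in page.data]


print(asyncio.run(list_openai_models("<OPENAI_API_KEY>")))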
@@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
     async def shutdown(self):
         pass
 
-    async def register_model(self, model: Model) -> Model:
-        model_id = self.get_provider_model_id(model.provider_resource_id)
-        if model_id is None:
-            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
-        return model
-
     def get_litellm_model_name(self, model_id: str) -> str:
         # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
         # model_id.startswith("openai/") is for backwards compatibility.
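With the override above removed, LiteLLM-backed providers presumably fall back to the register_model inherited from ModelRegistryHelper (still imported in this module). The sketch below restates the static alias check the deleted lines performed, purely as an illustration of that lookup and not as the base-class implementation.

# Illustration only: the validation the removed override used to perform.
# Names (get_provider_model_id, alias_to_provider_id_map, UnsupportedModelError)
# are taken from the deleted diff lines, not re-verified against the base class.
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import Model


async def validate_static_alias(mixin, model: Model) -> Model:
    if mixin.get_provider_model_id(model.provider_resource_id) is None:
        raise UnsupportedModelError(model.provider_resource_id, mixin.alias_to_provider_id_map.keys())
    return model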
pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "jinja2>=3.1.6",
     "jsonschema",
     "llama-stack-client>=0.2.15",
+    "llama-api-client>=0.1.2",
     "openai>=1.66",
     "prompt-toolkit",
     "python-dotenv",
uv.lock (generated): 19 changes
@@ -1268,6 +1268,23 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/2a/f7/67689245f48b9e79bcd2f3a10a3690cb1918fb99fffd5a623ed2496bca66/litellm-1.74.2-py3-none-any.whl", hash = "sha256:29bb555b45128e4cc696e72921a6ec24e97b14e9b69e86eed6f155124ad629b1", size = 8587065 },
 ]
 
+[[package]]
+name = "llama-api-client"
+version = "0.1.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anyio" },
+    { name = "distro" },
+    { name = "httpx" },
+    { name = "pydantic" },
+    { name = "sniffio" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d0/78/875de3a16efd0442718ac47cc27319cd80cc5f38e12298e454e08611acc4/llama_api_client-0.1.2.tar.gz", hash = "sha256:709011f2d506009b1b3b3bceea1c84f2a3a7600df1420fb256e680fcd7251387", size = 113695, upload-time = "2025-06-27T19:56:14.057Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/99/08/5d7e6e7e6af5353391376288c200acacebb8e6b156d3636eae598a451673/llama_api_client-0.1.2-py3-none-any.whl", hash = "sha256:8ad6e10726f74b2302bfd766c61c41355a9ecf60f57cde2961882d22af998941", size = 84091, upload-time = "2025-06-27T19:56:12.8Z" },
+]
+
 [[package]]
 name = "llama-stack"
 version = "0.2.15"
@@ -1283,6 +1300,7 @@ dependencies = [
     { name = "huggingface-hub" },
     { name = "jinja2" },
     { name = "jsonschema" },
+    { name = "llama-api-client" },
     { name = "llama-stack-client" },
     { name = "openai" },
     { name = "opentelemetry-exporter-otlp-proto-http" },
@@ -1398,6 +1416,7 @@ requires-dist = [
     { name = "jsonschema" },
     { name = "llama-stack-client", specifier = ">=0.2.15" },
     { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.15" },
+    { name = "llama-api-client", specifier = ">=0.1.2" },
     { name = "openai", specifier = ">=1.66" },
     { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
     { name = "opentelemetry-sdk", specifier = ">=1.30.0" },