Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit fd48cf3fc1 (parent b3cb8eaa38): Moved llama-stack-as-library client to llama-stack

4 changed files with 235 additions and 4 deletions
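
For orientation, here is a minimal sketch of how the relocated client is intended to be used, mirroring the test added in this commit; the config path and the choice of the first listed model are placeholders, not part of the diff.

import asyncio

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types import UserMessage


async def demo():
    # Accepts either a template name or a path to a run-config YAML (placeholder path here).
    client = LlamaStackAsLibraryClient("./run.yaml")
    await client.initialize()

    models = await client.models.list()
    response = await client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=models[0].identifier,
        stream=False,
    )
    print(response)


asyncio.run(demo())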
llama_stack/distribution/build.py (modified)

@@ -46,7 +46,7 @@ class ApiInput(BaseModel):
 
 
 def get_provider_dependencies(
-    config_providers: Dict[str, List[Provider]]
+    config_providers: Dict[str, List[Provider]],
 ) -> tuple[list[str], list[str]]:
     """Get normal and special dependencies from provider configuration."""
     all_providers = get_provider_registry()
@@ -92,11 +92,11 @@ def print_pip_install_help(providers: Dict[str, List[Provider]]):
     normal_deps, special_deps = get_provider_dependencies(providers)
 
     cprint(
-        f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}",
+        f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
         "yellow",
     )
     for special_dep in special_deps:
-        cprint(f"\tpip install {special_dep}", "yellow")
+        cprint(f"pip install {special_dep}", "yellow")
     print()
 
 
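For illustration only, a runnable sketch of what the updated helper prints; the dependency lists below are hypothetical. With the leading tab removed, the emitted pip install lines now start at column zero.

from termcolor import cprint

# Hypothetical dependency lists, for illustration only.
normal_deps = ["httpx", "pydantic"]
special_deps = ["torch --index-url https://download.pytorch.org/whl/cpu"]

cprint(
    f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
    "yellow",
)
for special_dep in special_deps:
    cprint(f"pip install {special_dep}", "yellow")
print()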
llama_stack/distribution/library_client.py (new file, 173 lines)
@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
from pathlib import Path
from typing import Any, get_args, get_origin, Optional

import yaml
from llama_stack_client import LlamaStackClient, NOT_GIVEN
from pydantic import TypeAdapter
from rich.console import Console

from termcolor import cprint

from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
    construct_stack,
    get_stack_run_config_from_template,
    replace_env_vars,
)


class LlamaStackAsLibraryClient(LlamaStackClient):
    def __init__(
        self,
        config_path_or_template_name: str,
        custom_provider_registry: Optional[ProviderRegistry] = None,
    ):
        if config_path_or_template_name.endswith(".yaml"):
            config_path = Path(config_path_or_template_name)
            if not config_path.exists():
                raise ValueError(f"Config file {config_path} does not exist")
            config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
            config = parse_and_maybe_upgrade_config(config_dict)
        else:
            # template
            config = get_stack_run_config_from_template(config_path_or_template_name)

        self.config_path_or_template_name = config_path_or_template_name
        self.config = config
        self.custom_provider_registry = custom_provider_registry

        super().__init__()

    async def initialize(self):
        try:
            self.impls = await construct_stack(
                self.config, self.custom_provider_registry
            )
        except ModuleNotFoundError as e:
            cprint(
                "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
                "yellow",
            )
            print_pip_install_help(self.config.providers)
            raise e

        console = Console()
        console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
        console.print(yaml.dump(self.config.model_dump(), indent=2))

        endpoints = get_all_api_endpoints()
        endpoint_impls = {}
        for api, api_endpoints in endpoints.items():
            for endpoint in api_endpoints:
                impl = self.impls[api]
                func = getattr(impl, endpoint.name)
                endpoint_impls[endpoint.route] = func

        self.endpoint_impls = endpoint_impls

    async def get(
        self,
        path: str,
        *,
        stream=False,
        **kwargs,
    ):
        if not self.endpoint_impls:
            raise ValueError("Client not initialized")

        if stream:
            return self._call_streaming(path, "GET")
        else:
            return await self._call_non_streaming(path, "GET")

    async def post(
        self,
        path: str,
        *,
        body: dict = None,
        stream=False,
        **kwargs,
    ):
        if not self.endpoint_impls:
            raise ValueError("Client not initialized")

        if stream:
            return self._call_streaming(path, "POST", body)
        else:
            return await self._call_non_streaming(path, "POST", body)

    async def _call_non_streaming(self, path: str, method: str, body: dict = None):
        func = self.endpoint_impls.get(path)
        if not func:
            raise ValueError(f"No endpoint found for {path}")

        body = self._convert_body(path, body)
        return await func(**body)

    async def _call_streaming(self, path: str, method: str, body: dict = None):
        func = self.endpoint_impls.get(path)
        if not func:
            raise ValueError(f"No endpoint found for {path}")

        body = self._convert_body(path, body)
        async for chunk in await func(**body):
            yield chunk

    def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
        if not body:
            return {}

        func = self.endpoint_impls[path]
        sig = inspect.signature(func)

        # Strip NOT_GIVENs to use the defaults in signature
        body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

        # Convert parameters to Pydantic models where needed
        converted_body = {}
        for param_name, param in sig.parameters.items():
            if param_name in body:
                value = body.get(param_name)
                converted_body[param_name] = self._convert_param(
                    param.annotation, value
                )
        return converted_body

    def _convert_param(self, annotation: Any, value: Any) -> Any:
        if isinstance(annotation, type) and annotation in {str, int, float, bool}:
            return value

        origin = get_origin(annotation)
        if origin is list:
            item_type = get_args(annotation)[0]
            try:
                return [self._convert_param(item_type, item) for item in value]
            except Exception:
                return value

        elif origin is dict:
            key_type, val_type = get_args(annotation)
            try:
                return {k: self._convert_param(val_type, v) for k, v in value.items()}
            except Exception:
                return value

        try:
            # Handle Pydantic models and discriminated unions
            return TypeAdapter(annotation).validate_python(value)
        except Exception as e:
            cprint(
                f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
                "yellow",
            )
            return value
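
Note that _convert_body/_convert_param ultimately rely on Pydantic's TypeAdapter to coerce plain dict arguments into the annotated parameter types. A standalone sketch of that fallback path, using a hypothetical Message model:

from pydantic import BaseModel, TypeAdapter


class Message(BaseModel):
    # Hypothetical stand-in for a stack API parameter type.
    role: str
    content: str


# For annotations that are not plain scalars, lists, or dicts, _convert_param
# falls back to TypeAdapter(annotation).validate_python(value):
msg = TypeAdapter(Message).validate_python(
    {"role": "user", "content": "What is the capital of France?"}
)
print(type(msg).__name__, msg.role)  # Message user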
llama_stack/distribution/tests/library_client_test.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types import UserMessage


async def main(config_path: str):
    client = LlamaStackAsLibraryClient(config_path)
    await client.initialize()

    models = await client.models.list()
    print(models)
    if not models:
        print("No models found, skipping chat completion test")
        return

    model_id = models[0].identifier
    response = await client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=False,
    )
    print("\nChat completion response:")
    print(response)

    response = await client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=True,
    )
    print("\nChat completion stream response:")
    async for chunk in response:
        print(chunk)

    response = await client.memory_banks.register(
        memory_bank_id="memory_bank_id",
        params={
            "chunk_size_in_tokens": 0,
            "embedding_model": "embedding_model",
            "memory_bank_type": "vector",
        },
    )
    print("\nRegister memory bank response:")
    print(response)


if __name__ == "__main__":
    import asyncio

    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", help="Path to the config YAML file")
    args = parser.parse_args()
    asyncio.run(main(args.config_path))
@@ -269,7 +269,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             r = await self.client.chat(**params)
         else:
             r = await self.client.generate(**params)
-        assert isinstance(r, dict)
 
         if "message" in r:
             choice = OpenAICompatCompletionChoice(