mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Make LlamaStackLibraryClient work correctly (#581)
This PR does a few things: - it moves "direct client" to llama-stack repo instead of being in the llama-stack-client-python repo - renames it to `LlamaStackLibraryClient` - actually makes synchronous generators work - makes streaming and non-streaming work properly In many ways, this PR makes things finally "work" ## Test Plan See a `library_client_test.py` I added. This isn't really quite a test yet but it demonstrates that this mode now works. Here's the invocation and the response: ``` INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct python llama_stack/distribution/tests/library_client_test.py ollama ``` 
This commit is contained in:
parent
b3cb8eaa38
commit
14f973a64f
4 changed files with 378 additions and 4 deletions
|
@ -46,7 +46,7 @@ class ApiInput(BaseModel):
|
|||
|
||||
|
||||
def get_provider_dependencies(
|
||||
config_providers: Dict[str, List[Provider]]
|
||||
config_providers: Dict[str, List[Provider]],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
all_providers = get_provider_registry()
|
||||
|
@ -92,11 +92,11 @@ def print_pip_install_help(providers: Dict[str, List[Provider]]):
|
|||
normal_deps, special_deps = get_provider_dependencies(providers)
|
||||
|
||||
cprint(
|
||||
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}",
|
||||
f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
|
||||
"yellow",
|
||||
)
|
||||
for special_dep in special_deps:
|
||||
cprint(f"\tpip install {special_dep}", "yellow")
|
||||
cprint(f"pip install {special_dep}", "yellow")
|
||||
print()
|
||||
|
||||
|
||||
|
|
272
llama_stack/distribution/library_client.py
Normal file
272
llama_stack/distribution/library_client.py
Normal file
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
|
||||
|
||||
import yaml
|
||||
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
|
||||
from pydantic import TypeAdapter
|
||||
from rich.console import Console
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def stream_across_asyncio_run_boundary(
|
||||
async_gen_maker,
|
||||
pool_executor: ThreadPoolExecutor,
|
||||
) -> Generator[T, None, None]:
|
||||
result_queue = queue.Queue()
|
||||
stop_event = threading.Event()
|
||||
|
||||
async def consumer():
|
||||
# make sure we make the generator in the event loop context
|
||||
gen = await async_gen_maker()
|
||||
try:
|
||||
async for item in gen:
|
||||
result_queue.put(item)
|
||||
except Exception as e:
|
||||
print(f"Error in generator {e}")
|
||||
result_queue.put(e)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
result_queue.put(StopIteration)
|
||||
stop_event.set()
|
||||
|
||||
def run_async():
|
||||
# Run our own loop to avoid double async generator cleanup which is done
|
||||
# by asyncio.run()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
task = loop.create_task(consumer())
|
||||
loop.run_until_complete(task)
|
||||
finally:
|
||||
# Handle pending tasks like a generator's athrow()
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
loop.close()
|
||||
|
||||
future = pool_executor.submit(run_async)
|
||||
|
||||
try:
|
||||
# yield results as they come in
|
||||
while not stop_event.is_set() or not result_queue.empty():
|
||||
try:
|
||||
item = result_queue.get(timeout=0.1)
|
||||
if item is StopIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
except queue.Empty:
|
||||
continue
|
||||
finally:
|
||||
future.result()
|
||||
|
||||
|
||||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_template_name, custom_provider_registry
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
def initialize(self):
|
||||
asyncio.run(self.async_client.initialize())
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
return stream_across_asyncio_run_boundary(
|
||||
lambda: self.async_client.get(*args, **kwargs),
|
||||
self.pool_executor,
|
||||
)
|
||||
else:
|
||||
return asyncio.run(self.async_client.get(*args, **kwargs))
|
||||
|
||||
def post(self, *args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
return stream_across_asyncio_run_boundary(
|
||||
lambda: self.async_client.post(*args, **kwargs),
|
||||
self.pool_executor,
|
||||
)
|
||||
else:
|
||||
return asyncio.run(self.async_client.post(*args, **kwargs))
|
||||
|
||||
|
||||
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config_path_or_template_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_template_name)
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Config file {config_path} does not exist")
|
||||
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
# template
|
||||
config = get_stack_run_config_from_template(config_path_or_template_name)
|
||||
|
||||
self.config_path_or_template_name = config_path_or_template_name
|
||||
self.config = config
|
||||
self.custom_provider_registry = custom_provider_registry
|
||||
|
||||
async def initialize(self):
|
||||
try:
|
||||
self.impls = await construct_stack(
|
||||
self.config, self.custom_provider_registry
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
cprint(
|
||||
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
|
||||
"yellow",
|
||||
)
|
||||
print_pip_install_help(self.config.providers)
|
||||
raise e
|
||||
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
||||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
for api, api_endpoints in endpoints.items():
|
||||
for endpoint in api_endpoints:
|
||||
impl = self.impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
endpoint_impls[endpoint.route] = func
|
||||
|
||||
self.endpoint_impls = endpoint_impls
|
||||
|
||||
async def get(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
):
|
||||
if not self.endpoint_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
if stream:
|
||||
return self._call_streaming(path, "GET")
|
||||
else:
|
||||
return await self._call_non_streaming(path, "GET")
|
||||
|
||||
async def post(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
body: dict = None,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
):
|
||||
if not self.endpoint_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
if stream:
|
||||
return self._call_streaming(path, "POST", body)
|
||||
else:
|
||||
return await self._call_non_streaming(path, "POST", body)
|
||||
|
||||
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
return await func(**body)
|
||||
|
||||
async def _call_streaming(self, path: str, method: str, body: dict = None):
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
async for chunk in await func(**body):
|
||||
yield chunk
|
||||
|
||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
func = self.endpoint_impls[path]
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
|
||||
|
||||
# Convert parameters to Pydantic models where needed
|
||||
converted_body = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name in body:
|
||||
value = body.get(param_name)
|
||||
converted_body[param_name] = self._convert_param(
|
||||
param.annotation, value
|
||||
)
|
||||
return converted_body
|
||||
|
||||
def _convert_param(self, annotation: Any, value: Any) -> Any:
|
||||
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
|
||||
return value
|
||||
|
||||
origin = get_origin(annotation)
|
||||
if origin is list:
|
||||
item_type = get_args(annotation)[0]
|
||||
try:
|
||||
return [self._convert_param(item_type, item) for item in value]
|
||||
except Exception:
|
||||
print(f"Error converting list {value}")
|
||||
return value
|
||||
|
||||
elif origin is dict:
|
||||
key_type, val_type = get_args(annotation)
|
||||
try:
|
||||
return {k: self._convert_param(val_type, v) for k, v in value.items()}
|
||||
except Exception:
|
||||
print(f"Error converting dict {value}")
|
||||
return value
|
||||
|
||||
try:
|
||||
# Handle Pydantic models and discriminated unions
|
||||
return TypeAdapter(annotation).validate_python(value)
|
||||
except Exception as e:
|
||||
cprint(
|
||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||
"yellow",
|
||||
)
|
||||
return value
|
103
llama_stack/distribution/tests/library_client_test.py
Normal file
103
llama_stack/distribution/tests/library_client_test.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
|
||||
from llama_stack_client.lib.inference.event_logger import EventLogger
|
||||
from llama_stack_client.types import UserMessage
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
|
||||
|
||||
def main(config_path: str):
|
||||
client = LlamaStackAsLibraryClient(config_path)
|
||||
client.initialize()
|
||||
|
||||
models = client.models.list()
|
||||
print("\nModels:")
|
||||
for model in models:
|
||||
print(model)
|
||||
|
||||
if not models:
|
||||
print("No models found, skipping chat completion test")
|
||||
return
|
||||
|
||||
model_id = models[0].identifier
|
||||
response = client.inference.chat_completion(
|
||||
messages=[UserMessage(content="What is the capital of France?", role="user")],
|
||||
model_id=model_id,
|
||||
stream=False,
|
||||
)
|
||||
print("\nChat completion response (non-stream):")
|
||||
print(response)
|
||||
|
||||
response = client.inference.chat_completion(
|
||||
messages=[UserMessage(content="What is the capital of France?", role="user")],
|
||||
model_id=model_id,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
print("\nChat completion response (stream):")
|
||||
for log in EventLogger().log(response):
|
||||
log.print()
|
||||
|
||||
print("\nAgent test:")
|
||||
agent_config = AgentConfig(
|
||||
model=model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
sampling_params={
|
||||
"strategy": "greedy",
|
||||
"temperature": 1.0,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
tools=(
|
||||
[
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "brave",
|
||||
"api_key": os.getenv("BRAVE_SEARCH_API_KEY"),
|
||||
}
|
||||
]
|
||||
if os.getenv("BRAVE_SEARCH_API_KEY")
|
||||
else []
|
||||
),
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="json",
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
agent = Agent(client, agent_config)
|
||||
user_prompts = [
|
||||
"Hello",
|
||||
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
|
||||
]
|
||||
|
||||
session_id = agent.create_session("test-session")
|
||||
|
||||
for prompt in user_prompts:
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("config_path", help="Path to the config YAML file")
|
||||
args = parser.parse_args()
|
||||
main(args.config_path)
|
|
@ -269,7 +269,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
r = await self.client.chat(**params)
|
||||
else:
|
||||
r = await self.client.generate(**params)
|
||||
assert isinstance(r, dict)
|
||||
|
||||
if "message" in r:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue