From 14f973a64f4f6bee011d94910eea67d75375998f Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Sat, 7 Dec 2024 14:59:36 -0800
Subject: [PATCH] Make LlamaStackLibraryClient work correctly (#581)

This PR does a few things:

- moves the "direct client" into the llama-stack repo instead of keeping it in the llama-stack-client-python repo
- renames it to `LlamaStackAsLibraryClient`
- actually makes synchronous generators work
- makes streaming and non-streaming calls work properly

In many ways, this PR makes the library-client mode finally "work".

## Test Plan

See `library_client_test.py` added in this PR. It is not quite a real test yet, but it demonstrates that this mode now works. Here's the invocation and the response:

```
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct python llama_stack/distribution/tests/library_client_test.py ollama
```

![image](https://github.com/user-attachments/assets/17d4e116-4457-4755-a14e-d9a668801fe0)
---
 llama_stack/distribution/build.py              |   6 +-
 llama_stack/distribution/library_client.py     | 272 ++++++++++++++++++
 .../distribution/tests/library_client_test.py  | 103 +++++++
 .../remote/inference/ollama/ollama.py          |   1 -
 4 files changed, 378 insertions(+), 4 deletions(-)
 create mode 100644 llama_stack/distribution/library_client.py
 create mode 100644 llama_stack/distribution/tests/library_client_test.py

diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py
index 9d0ad9af4..3349a7d50 100644
--- a/llama_stack/distribution/build.py
+++ b/llama_stack/distribution/build.py
@@ -46,7 +46,7 @@ class ApiInput(BaseModel):
 
 
 def get_provider_dependencies(
-    config_providers: Dict[str, List[Provider]]
+    config_providers: Dict[str, List[Provider]],
 ) -> tuple[list[str], list[str]]:
     """Get normal and special dependencies from provider configuration."""
     all_providers = get_provider_registry()
@@ -92,11 +92,11 @@ def print_pip_install_help(providers: Dict[str, List[Provider]]):
     normal_deps, special_deps = get_provider_dependencies(providers)
 
     cprint(
-        f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}",
+        f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
         "yellow",
     )
     for special_dep in special_deps:
-        cprint(f"\tpip install {special_dep}", "yellow")
+        cprint(f"pip install {special_dep}", "yellow")
     print()
 
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
new file mode 100644
index 000000000..4de06ae08
--- /dev/null
+++ b/llama_stack/distribution/library_client.py
@@ -0,0 +1,272 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
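+#
+# This module lets a Llama Stack distribution run in-process ("library mode")
+# instead of behind a server. LlamaStackAsLibraryClient is the synchronous entry
+# point; it delegates to AsyncLlamaStackAsLibraryClient and bridges streaming
+# responses across the asyncio.run() boundary with a background thread and queue.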
+
+import asyncio
+import inspect
+import queue
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+
+import yaml
+from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
+from pydantic import TypeAdapter
+from rich.console import Console
+
+from termcolor import cprint
+
+from llama_stack.distribution.build import print_pip_install_help
+from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.resolver import ProviderRegistry
+from llama_stack.distribution.server.endpoints import get_all_api_endpoints
+from llama_stack.distribution.stack import (
+    construct_stack,
+    get_stack_run_config_from_template,
+    replace_env_vars,
+)
+
+T = TypeVar("T")
+
+
+def stream_across_asyncio_run_boundary(
+    async_gen_maker,
+    pool_executor: ThreadPoolExecutor,
+) -> Generator[T, None, None]:
+    result_queue = queue.Queue()
+    stop_event = threading.Event()
+
+    async def consumer():
+        # make sure we make the generator in the event loop context
+        gen = await async_gen_maker()
+        try:
+            async for item in gen:
+                result_queue.put(item)
+        except Exception as e:
+            print(f"Error in generator {e}")
+            result_queue.put(e)
+        except asyncio.CancelledError:
+            return
+        finally:
+            result_queue.put(StopIteration)
+            stop_event.set()
+
+    def run_async():
+        # Run our own loop to avoid double async generator cleanup which is done
+        # by asyncio.run()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            task = loop.create_task(consumer())
+            loop.run_until_complete(task)
+        finally:
+            # Handle pending tasks like a generator's athrow()
+            pending = asyncio.all_tasks(loop)
+            if pending:
+                loop.run_until_complete(
+                    asyncio.gather(*pending, return_exceptions=True)
+                )
+            loop.close()
+
+    future = pool_executor.submit(run_async)
+
+    try:
+        # yield results as they come in
+        while not stop_event.is_set() or not result_queue.empty():
+            try:
+                item = result_queue.get(timeout=0.1)
+                if item is StopIteration:
+                    break
+                if isinstance(item, Exception):
+                    raise item
+                yield item
+            except queue.Empty:
+                continue
+    finally:
+        future.result()
+
+
+class LlamaStackAsLibraryClient(LlamaStackClient):
+    def __init__(
+        self,
+        config_path_or_template_name: str,
+        custom_provider_registry: Optional[ProviderRegistry] = None,
+    ):
+        super().__init__()
+        self.async_client = AsyncLlamaStackAsLibraryClient(
+            config_path_or_template_name, custom_provider_registry
+        )
+        self.pool_executor = ThreadPoolExecutor(max_workers=4)
+
+    def initialize(self):
+        asyncio.run(self.async_client.initialize())
+
+    def get(self, *args, **kwargs):
+        if kwargs.get("stream"):
+            return stream_across_asyncio_run_boundary(
+                lambda: self.async_client.get(*args, **kwargs),
+                self.pool_executor,
+            )
+        else:
+            return asyncio.run(self.async_client.get(*args, **kwargs))
+
+    def post(self, *args, **kwargs):
+        if kwargs.get("stream"):
+            return stream_across_asyncio_run_boundary(
+                lambda: self.async_client.post(*args, **kwargs),
+                self.pool_executor,
+            )
+        else:
+            return asyncio.run(self.async_client.post(*args, **kwargs))
+
+
+class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
+    def __init__(
+        self,
+        config_path_or_template_name: str,
+        custom_provider_registry: Optional[ProviderRegistry] = None,
+    ):
+        super().__init__()
+
+        if config_path_or_template_name.endswith(".yaml"):
+            config_path = Path(config_path_or_template_name)
+            if not config_path.exists():
+                raise ValueError(f"Config file {config_path} does not exist")
+            config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
+            config = parse_and_maybe_upgrade_config(config_dict)
+        else:
+            # template
+            config = get_stack_run_config_from_template(config_path_or_template_name)
+
+        self.config_path_or_template_name = config_path_or_template_name
+        self.config = config
+        self.custom_provider_registry = custom_provider_registry
+
+    async def initialize(self):
+        try:
+            self.impls = await construct_stack(
+                self.config, self.custom_provider_registry
+            )
+        except ModuleNotFoundError as e:
+            cprint(
+                "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
+                "yellow",
+            )
+            print_pip_install_help(self.config.providers)
+            raise e
+
+        console = Console()
+        console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
+        console.print(yaml.dump(self.config.model_dump(), indent=2))
+
+        endpoints = get_all_api_endpoints()
+        endpoint_impls = {}
+        for api, api_endpoints in endpoints.items():
+            for endpoint in api_endpoints:
+                impl = self.impls[api]
+                func = getattr(impl, endpoint.name)
+                endpoint_impls[endpoint.route] = func
+
+        self.endpoint_impls = endpoint_impls
+
+    async def get(
+        self,
+        path: str,
+        *,
+        stream=False,
+        **kwargs,
+    ):
+        if not self.endpoint_impls:
+            raise ValueError("Client not initialized")
+
+        if stream:
+            return self._call_streaming(path, "GET")
+        else:
+            return await self._call_non_streaming(path, "GET")
+
+    async def post(
+        self,
+        path: str,
+        *,
+        body: dict = None,
+        stream=False,
+        **kwargs,
+    ):
+        if not self.endpoint_impls:
+            raise ValueError("Client not initialized")
+
+        if stream:
+            return self._call_streaming(path, "POST", body)
+        else:
+            return await self._call_non_streaming(path, "POST", body)
+
+    async def _call_non_streaming(self, path: str, method: str, body: dict = None):
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")
+
+        body = self._convert_body(path, body)
+        return await func(**body)
+
+    async def _call_streaming(self, path: str, method: str, body: dict = None):
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")
+
+        body = self._convert_body(path, body)
+        async for chunk in await func(**body):
+            yield chunk
+
+    def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
+        if not body:
+            return {}
+
+        func = self.endpoint_impls[path]
+        sig = inspect.signature(func)
+
+        # Strip NOT_GIVENs to use the defaults in signature
+        body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
+
+        # Convert parameters to Pydantic models where needed
+        converted_body = {}
+        for param_name, param in sig.parameters.items():
+            if param_name in body:
+                value = body.get(param_name)
+                converted_body[param_name] = self._convert_param(
+                    param.annotation, value
+                )
+        return converted_body
+
+    def _convert_param(self, annotation: Any, value: Any) -> Any:
+        if isinstance(annotation, type) and annotation in {str, int, float, bool}:
+            return value
+
+        origin = get_origin(annotation)
+        if origin is list:
+            item_type = get_args(annotation)[0]
+            try:
+                return [self._convert_param(item_type, item) for item in value]
+            except Exception:
+                print(f"Error converting list {value}")
+                return value
+
+        elif origin is dict:
+            key_type, val_type = get_args(annotation)
+            try:
+                return {k: self._convert_param(val_type, v) for k, v in value.items()}
+            except Exception:
print(f"Error converting dict {value}") + return value + + try: + # Handle Pydantic models and discriminated unions + return TypeAdapter(annotation).validate_python(value) + except Exception as e: + cprint( + f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", + "yellow", + ) + return value diff --git a/llama_stack/distribution/tests/library_client_test.py b/llama_stack/distribution/tests/library_client_test.py new file mode 100644 index 000000000..8381f5470 --- /dev/null +++ b/llama_stack/distribution/tests/library_client_test.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import os + +from llama_stack.distribution.library_client import LlamaStackAsLibraryClient +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger +from llama_stack_client.lib.inference.event_logger import EventLogger +from llama_stack_client.types import UserMessage +from llama_stack_client.types.agent_create_params import AgentConfig + + +def main(config_path: str): + client = LlamaStackAsLibraryClient(config_path) + client.initialize() + + models = client.models.list() + print("\nModels:") + for model in models: + print(model) + + if not models: + print("No models found, skipping chat completion test") + return + + model_id = models[0].identifier + response = client.inference.chat_completion( + messages=[UserMessage(content="What is the capital of France?", role="user")], + model_id=model_id, + stream=False, + ) + print("\nChat completion response (non-stream):") + print(response) + + response = client.inference.chat_completion( + messages=[UserMessage(content="What is the capital of France?", role="user")], + model_id=model_id, + stream=True, + ) + + print("\nChat completion response (stream):") + for log in EventLogger().log(response): + log.print() + + print("\nAgent test:") + agent_config = AgentConfig( + model=model_id, + instructions="You are a helpful assistant", + sampling_params={ + "strategy": "greedy", + "temperature": 1.0, + "top_p": 0.9, + }, + tools=( + [ + { + "type": "brave_search", + "engine": "brave", + "api_key": os.getenv("BRAVE_SEARCH_API_KEY"), + } + ] + if os.getenv("BRAVE_SEARCH_API_KEY") + else [] + ), + tool_choice="auto", + tool_prompt_format="json", + input_shields=[], + output_shields=[], + enable_session_persistence=False, + ) + agent = Agent(client, agent_config) + user_prompts = [ + "Hello", + "Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools", + ] + + session_id = agent.create_session("test-session") + + for prompt in user_prompts: + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": prompt, + } + ], + session_id=session_id, + ) + + for log in AgentEventLogger().log(response): + log.print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("config_path", help="Path to the config YAML file") + args = parser.parse_args() + main(args.config_path) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index f89629afc..d6fa20835 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -269,7 +269,6 @@ 
             r = await self.client.chat(**params)
         else:
             r = await self.client.generate(**params)
-        assert isinstance(r, dict)
 
         if "message" in r:
             choice = OpenAICompatCompletionChoice(
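
For quick reference, below is a minimal usage sketch condensed from `library_client_test.py` above. It assumes the `ollama` template's dependencies are installed and that the distribution already has at least one model registered; only calls that appear in this PR (`initialize()`, `models.list()`, `inference.chat_completion(...)`) are used, and the prompt strings are illustrative.

```
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types import UserMessage

# Construct the stack in-process from a template name (or a path to a run .yaml file).
client = LlamaStackAsLibraryClient("ollama")
client.initialize()

# Pick the first registered model.
model_id = client.models.list()[0].identifier

# Non-streaming: the sync client runs the async call via asyncio.run() and
# returns a single response object.
response = client.inference.chat_completion(
    messages=[UserMessage(content="What is the capital of France?", role="user")],
    model_id=model_id,
    stream=False,
)
print(response)

# Streaming: chunks are produced on a background event loop and yielded
# synchronously through stream_across_asyncio_run_boundary.
for chunk in client.inference.chat_completion(
    messages=[UserMessage(content="What is the capital of France?", role="user")],
    model_id=model_id,
    stream=True,
):
    print(chunk)
```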