diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 4ce3ec272..50b867366 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -13,10 +13,20 @@ import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+
+import httpx
 import yaml
-from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
+from llama_stack_client import (
+    APIResponse,
+    AsyncAPIResponse,
+    AsyncLlamaStackClient,
+    AsyncStream,
+    LlamaStackClient,
+    NOT_GIVEN,
+    Stream,
+)
 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
 
@@ -66,7 +76,7 @@ def stream_across_asyncio_run_boundary(
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
         try:
-            async for item in gen:
+            async for item in await gen:
                 result_queue.put(item)
         except Exception as e:
             print(f"Error in generator {e}")
@@ -112,31 +122,17 @@ def stream_across_asyncio_run_boundary(
         future.result()
 
 
-def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
+def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
     elif isinstance(value, list):
-        return [convert_pydantic_to_json_value(item, cast_to) for item in value]
+        return [convert_pydantic_to_json_value(item) for item in value]
     elif isinstance(value, dict):
-        return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
     elif isinstance(value, BaseModel):
-        # This is quite hacky and we should figure out how to use stuff from
-        # generated client-sdk code (using ApiResponse.parse() essentially)
-        value_dict = json.loads(value.model_dump_json())
-
-        origin = get_origin(cast_to)
-        if origin is Union:
-            args = get_args(cast_to)
-            for arg in args:
-                arg_name = arg.__name__.split(".")[-1]
-                value_name = value.__class__.__name__.split(".")[-1]
-                if arg_name == value_name:
-                    return arg(**value_dict)
-
-        # assume we have the correct association between the server-side type and the client-side type
-        return cast_to(**value_dict)
-
-    return value
+        return json.loads(value.model_dump_json())
+    else:
+        return value
 
 
 def convert_to_pydantic(annotation: Any, value: Any) -> Any:
@@ -278,16 +274,28 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
-        params = options.params or {}
-        params |= options.json_data or {}
         if stream:
-            return self._call_streaming(options.url, params, cast_to)
+            return self._call_streaming(
+                cast_to=cast_to,
+                options=options,
+                stream_cls=stream_cls,
+            )
         else:
-            return await self._call_non_streaming(options.url, params, cast_to)
+            return await self._call_non_streaming(
+                cast_to=cast_to,
+                options=options,
+            )
 
     async def _call_non_streaming(
-        self, path: str, body: dict = None, cast_to: Any = None
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
     ):
+        path = options.url
+
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
@@ -295,11 +303,45 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            return convert_pydantic_to_json_value(await func(**body), cast_to)
+            result = await func(**body)
+
+            json_content = json.dumps(convert_pydantic_to_json_value(result))
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=json_content.encode("utf-8"),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+            response = APIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=False,
+                stream_cls=None,
+            )
+            return response.parse()
         finally:
             await end_trace()
 
-    async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
+    async def _call_streaming(
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
+        stream_cls: Any,
+    ):
+        path = options.url
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
@@ -307,8 +349,40 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            async for chunk in await func(**body):
-                yield convert_pydantic_to_json_value(chunk, cast_to)
+
+            async def gen():
+                async for chunk in await func(**body):
+                    data = json.dumps(convert_pydantic_to_json_value(chunk))
+                    sse_event = f"data: {data}\n\n"
+                    yield sse_event.encode("utf-8")
+
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=gen(),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+            origin = get_origin(stream_cls)
+            assert origin is Stream
+            args = get_args(stream_cls)
+            stream_cls = AsyncStream[args[0]]
+            response = AsyncAPIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=True,
+                stream_cls=stream_cls,
+            )
+            return await response.parse()
         finally:
             await end_trace()
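The streaming path above no longer yields plain JSON values; it re-frames each chunk as a server-sent event and lets AsyncAPIResponse.parse() hand back a typed AsyncStream, exactly as if the bytes had arrived over HTTP. A minimal sketch of that framing, using a stand-in pydantic model (Chunk and its fields are illustrative, not llama-stack types):

import json
from enum import Enum

from pydantic import BaseModel


class Status(Enum):
    OK = "ok"


class Chunk(BaseModel):  # stand-in for a real server-side chunk type
    delta: str
    status: Status


def convert_pydantic_to_json_value(value):
    # mirrors the simplified converter from the diff above
    if isinstance(value, Enum):
        return value.value
    elif isinstance(value, list):
        return [convert_pydantic_to_json_value(v) for v in value]
    elif isinstance(value, dict):
        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
    elif isinstance(value, BaseModel):
        return json.loads(value.model_dump_json())
    else:
        return value


chunk = Chunk(delta="hello", status=Status.OK)
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"  # one SSE frame per chunk
assert sse_event.encode("utf-8") == b'data: {"delta": "hello", "status": "ok"}\n\n'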
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 75126c221..5671082d5 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -6,6 +6,7 @@
 
 import logging
 import os
+import re
 from pathlib import Path
 from typing import Any, Dict
 
@@ -143,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
                 if default_val is None:
                     raise EnvVarError(env_var, path)
                 else:
-                    value = default_val
+                    value = default_val if default_val != "null" else None
 
             # expand "~" from the values
             return os.path.expanduser(value)
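With this one-line change, a default of the literal string "null" now resolves to None instead of the string "null", so a run config can express "explicitly unset". A hedged re-implementation of just this branch — the ${env.VAR:default} placeholder syntax and the regex are assumptions based on the surrounding code, and the sketch adds an explicit guard before os.path.expanduser, which cannot accept None:

import os
import re


def resolve(token: str):
    # illustrative pattern; the real one lives inside replace_env_vars
    match = re.fullmatch(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}", token)
    env_var, default_val = match.group(1), match.group(2)

    value = os.environ.get(env_var)
    if not value:
        if default_val is None:
            raise ValueError(f"{env_var} is not set and has no default")
        value = default_val if default_val != "null" else None

    return os.path.expanduser(value) if value is not None else None


assert resolve("${env.SOME_UNSET_VAR:null}") is None  # new behavior
assert resolve("${env.SOME_UNSET_VAR:disabled}") == "disabled"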
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index 8f93c0c4b..f98c14443 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import json
 from contextlib import asynccontextmanager
 from typing import Dict, List, Optional, Protocol, Tuple
 
@@ -54,10 +53,7 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
     """Utility function to parse registry values into RoutableObjectWithProvider objects."""
     all_objects = []
     for value in values:
-        obj = pydantic.parse_obj_as(
-            RoutableObjectWithProvider,
-            json.loads(value),
-        )
+        obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
         all_objects.append(obj)
 
     return all_objects
@@ -89,14 +85,7 @@ class DiskDistributionRegistry(DistributionRegistry):
         if not json_str:
             return None
 
-        objects_data = json.loads(json_str)
-        # Return only the first object if any exist
-        if objects_data:
-            return pydantic.parse_obj_as(
-                RoutableObjectWithProvider,
-                json.loads(objects_data),
-            )
-        return None
+        return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
 
     async def update(self, obj: RoutableObjectWithProvider) -> None:
         await self.kvstore.set(
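Both registry call sites now validate straight from the stored JSON string with pydantic v2's TypeAdapter, replacing the deprecated parse_obj_as plus a manual json.loads (the removed disk-path code even called json.loads twice on the same value). A sketch of the migration pattern, with a stand-in discriminated union in place of RoutableObjectWithProvider:

import json
from typing import Annotated, Literal, Union

import pydantic
from pydantic import BaseModel, Field


class ModelResource(BaseModel):  # stand-ins, not the real llama-stack types
    type: Literal["model"] = "model"
    identifier: str


class ShieldResource(BaseModel):
    type: Literal["shield"] = "shield"
    identifier: str


Routable = Annotated[Union[ModelResource, ShieldResource], Field(discriminator="type")]

raw = '{"type": "shield", "identifier": "llama_guard"}'

# before: decode to a dict, then validate with the deprecated v1-style helper
old = pydantic.parse_obj_as(Routable, json.loads(raw))

# after: one call that validates directly from the JSON string
new = pydantic.TypeAdapter(Routable).validate_json(raw)

assert old == new and isinstance(new, ShieldResource)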
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index a0e8c973f..4f3fda8c3 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -8,6 +8,7 @@ import json
 from typing import Dict, List
 from uuid import uuid4
 
+import pytest
 from llama_stack.providers.tests.env import get_env_or_fail
 
 from llama_stack_client.lib.agents.agent import Agent
@@ -77,16 +78,20 @@ class TestCustomTool(CustomTool):
         return -1
 
 
-def get_agent_config_with_available_models_shields(llama_stack_client):
+@pytest.fixture(scope="session")
+def agent_config(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
-        if model.identifier.startswith("meta-llama")
+        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     model_id = available_models[0]
+    print(f"Using model: {model_id}")
     available_shields = [
         shield.identifier for shield in llama_stack_client.shields.list()
     ]
+    available_shields = available_shields[:1]
+    print(f"Using shield: {available_shields}")
     agent_config = AgentConfig(
         model=model_id,
         instructions="You are a helpful assistant",
@@ -105,8 +110,7 @@ def get_agent_config_with_available_models_shields(llama_stack_client):
     return agent_config
 
 
-def test_agent_simple(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
+def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
@@ -142,16 +146,18 @@ def test_agent_simple(llama_stack_client):
     assert "I can't" in logs_str
 
 
-def test_builtin_tool_brave_search(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        }
-    ]
-    print(agent_config)
+def test_builtin_tool_brave_search(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            }
+        ],
+    }
+    print(f"Agent Config: {agent_config}")
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
@@ -174,13 +180,15 @@ def test_builtin_tool_brave_search(llama_stack_client):
     assert "No Violation" in logs_str
 
 
-def test_builtin_tool_code_execution(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "code_interpreter",
-        }
-    ]
+def test_builtin_tool_code_execution(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "code_interpreter",
+            }
+        ],
+    }
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
@@ -200,34 +208,36 @@ def test_builtin_tool_code_execution(llama_stack_client):
     assert "Tool:code_interpreter Response" in logs_str
 
 
-def test_custom_tool(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["model"] = "meta-llama/Llama-3.2-3B-Instruct"
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        },
-        {
-            "function_name": "get_boiling_point",
-            "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
-            "parameters": {
-                "liquid_name": {
-                    "param_type": "str",
-                    "description": "The name of the liquid",
-                    "required": True,
-                },
-                "celcius": {
-                    "param_type": "boolean",
-                    "description": "Whether to return the boiling point in Celcius",
-                    "required": False,
-                },
+def test_custom_tool(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "model": "meta-llama/Llama-3.2-3B-Instruct",
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
             },
-            "type": "function_call",
-        },
-    ]
-    agent_config["tool_prompt_format"] = "python_list"
+            {
+                "function_name": "get_boiling_point",
+                "description": "Get the boiling point of an imaginary liquid (e.g. polyjuice)",
+                "parameters": {
+                    "liquid_name": {
+                        "param_type": "str",
+                        "description": "The name of the liquid",
+                        "required": True,
+                    },
+                    "celcius": {
+                        "param_type": "boolean",
+                        "description": "Whether to return the boiling point in Celsius",
+                        "required": False,
+                    },
+                },
+                "type": "function_call",
+            },
+        ],
+        "tool_prompt_format": "python_list",
+    }
     agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py
index 4e56254c1..2366008dd 100644
--- a/tests/client-sdk/conftest.py
+++ b/tests/client-sdk/conftest.py
@@ -3,13 +3,22 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import os
+
 import pytest
 
+from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def llama_stack_client():
-    """Fixture to create a fresh LlamaStackClient instance for each test"""
-    return LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    if os.environ.get("LLAMA_STACK_CONFIG"):
+        client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
+        client.initialize()
+    elif os.environ.get("LLAMA_STACK_BASE_URL"):
+        client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    else:
+        raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+    return client
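The reworked fixture lets the same client-sdk suite run in two modes: set LLAMA_STACK_CONFIG to a stack run config to drive everything in-process through the new library client, or set LLAMA_STACK_BASE_URL to target an already-running server over HTTP. And because llama_stack_client and agent_config are now session-scoped, the agent tests above copy the config instead of mutating it, so per-test tool overrides cannot leak into later tests. A tiny illustration of why the spread pattern matters (values are made up):

# session-scoped fixture value, shared by every test in the run
agent_config = {"model": "meta-llama/Llama-3.2-3B-Instruct", "tools": []}

# per-test override: spread into a new dict rather than assigning in place
local_config = {**agent_config, "tools": [{"type": "code_interpreter"}]}

assert agent_config["tools"] == []  # the shared fixture value is untouched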
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 245524510..d00ae12a8 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -4,10 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import sys
+import traceback
+import warnings
+
 import pytest
 from llama_stack_client.lib.inference.event_logger import EventLogger
 
 
+def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
+    log = file if hasattr(file, "write") else sys.stderr
+    traceback.print_stack(file=log)
+    log.write(warnings.formatwarning(message, category, filename, lineno, line))
+
+
+warnings.showwarning = warn_with_traceback
+
+
 def test_text_chat_completion(llama_stack_client):
     # non-streaming
     available_models = [
@@ -55,11 +68,15 @@ def test_image_chat_completion(llama_stack_client):
         "role": "user",
         "content": [
             {
-                "image": {
+                "type": "image",
+                "data": {
                     "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
-                }
+                },
+            },
+            {
+                "type": "text",
+                "text": "Describe what is in this image.",
             },
-            "Describe what is in this image.",
         ],
     }
     response = llama_stack_client.inference.chat_completion(
diff --git a/tests/client-sdk/memory/test_memory.py b/tests/client-sdk/memory/test_memory.py
index 8465d5aef..bb5c60240 100644
--- a/tests/client-sdk/memory/test_memory.py
+++ b/tests/client-sdk/memory/test_memory.py
@@ -22,6 +22,7 @@ def test_memory_bank(llama_stack_client):
     llama_stack_client.memory_banks.register(
         memory_bank_id=memory_bank_id,
         params={
+            "memory_bank_type": "vector",
             "embedding_model": "all-MiniLM-L6-v2",
             "chunk_size_in_tokens": 512,
             "overlap_size_in_tokens": 64,