Fix LibraryClient completely correctly; also make tests pass

This commit is contained in:
Ashwin Bharambe 2024-12-16 22:16:21 -08:00
parent d4935ca439
commit 1bcc26ccd1
7 changed files with 201 additions and 100 deletions

View file

@ -13,10 +13,20 @@ import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import httpx
import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from llama_stack_client import (
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
NOT_GIVEN,
Stream,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
@ -66,7 +76,7 @@ def stream_across_asyncio_run_boundary(
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
try:
async for item in gen:
async for item in await gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
@ -112,31 +122,17 @@ def stream_across_asyncio_run_boundary(
future.result()
def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item, cast_to) for item in value]
return [convert_pydantic_to_json_value(item) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
# This is quite hacky and we should figure out how to use stuff from
# generated client-sdk code (using ApiResponse.parse() essentially)
value_dict = json.loads(value.model_dump_json())
origin = get_origin(cast_to)
if origin is Union:
args = get_args(cast_to)
for arg in args:
arg_name = arg.__name__.split(".")[-1]
value_name = value.__class__.__name__.split(".")[-1]
if arg_name == value_name:
return arg(**value_dict)
# assume we have the correct association between the server-side type and the client-side type
return cast_to(**value_dict)
return value
return json.loads(value.model_dump_json())
else:
return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
@ -278,16 +274,28 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
params = options.params or {}
params |= options.json_data or {}
if stream:
return self._call_streaming(options.url, params, cast_to)
return self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
return await self._call_non_streaming(options.url, params, cast_to)
return await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
async def _call_non_streaming(
self, path: str, body: dict = None, cast_to: Any = None
self,
*,
cast_to: Any,
options: Any,
):
path = options.url
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -295,11 +303,45 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
return convert_pydantic_to_json_value(await func(**body), cast_to)
result = await func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
finally:
await end_trace()
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
async def _call_streaming(
self,
*,
cast_to: Any,
options: Any,
stream_cls: Any,
):
path = options.url
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -307,8 +349,40 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
async for chunk in await func(**body):
yield convert_pydantic_to_json_value(chunk, cast_to)
async def gen():
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
origin = get_origin(stream_cls)
assert origin is Stream
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
finally:
await end_trace()

View file

@ -6,6 +6,7 @@
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict
@ -143,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
if default_val is None:
raise EnvVarError(env_var, path)
else:
value = default_val
value = default_val if default_val != "null" else None
# expand "~" from the values
return os.path.expanduser(value)

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
@ -54,10 +53,7 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(value),
)
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
return all_objects
@ -89,14 +85,7 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None
objects_data = json.loads(json_str)
# Return only the first object if any exist
if objects_data:
return pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(objects_data),
)
return None
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(

View file

@ -8,6 +8,7 @@ import json
from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client.lib.agents.agent import Agent
@ -77,16 +78,20 @@ class TestCustomTool(CustomTool):
return -1
def get_agent_config_with_available_models_shields(llama_stack_client):
@pytest.fixture(scope="session")
def agent_config(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama")
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
model_id = available_models[0]
print(f"Using model: {model_id}")
available_shields = [
shield.identifier for shield in llama_stack_client.shields.list()
]
available_shields = available_shields[:1]
print(f"Using shield: {available_shields}")
agent_config = AgentConfig(
model=model_id,
instructions="You are a helpful assistant",
@ -105,8 +110,7 @@ def get_agent_config_with_available_models_shields(llama_stack_client):
return agent_config
def test_agent_simple(llama_stack_client):
agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
def test_agent_simple(llama_stack_client, agent_config):
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@ -142,16 +146,18 @@ def test_agent_simple(llama_stack_client):
assert "I can't" in logs_str
def test_builtin_tool_brave_search(llama_stack_client):
agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
agent_config["tools"] = [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
}
]
print(agent_config)
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
}
],
}
print(f"Agent Config: {agent_config}")
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@ -174,13 +180,15 @@ def test_builtin_tool_brave_search(llama_stack_client):
assert "No Violation" in logs_str
def test_builtin_tool_code_execution(llama_stack_client):
agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
agent_config["tools"] = [
{
"type": "code_interpreter",
}
]
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
{
"type": "code_interpreter",
}
],
}
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@ -200,34 +208,36 @@ def test_builtin_tool_code_execution(llama_stack_client):
assert "Tool:code_interpreter Response" in logs_str
def test_custom_tool(llama_stack_client):
agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
agent_config["model"] = "meta-llama/Llama-3.2-3B-Instruct"
agent_config["tools"] = [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
},
{
"function_name": "get_boiling_point",
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
"parameters": {
"liquid_name": {
"param_type": "str",
"description": "The name of the liquid",
"required": True,
},
"celcius": {
"param_type": "boolean",
"description": "Whether to return the boiling point in Celcius",
"required": False,
},
def test_custom_tool(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
"tools": [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
},
"type": "function_call",
},
]
agent_config["tool_prompt_format"] = "python_list"
{
"function_name": "get_boiling_point",
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
"parameters": {
"liquid_name": {
"param_type": "str",
"description": "The name of the liquid",
"required": True,
},
"celcius": {
"param_type": "boolean",
"description": "Whether to return the boiling point in Celcius",
"required": False,
},
},
"type": "function_call",
},
],
"tool_prompt_format": "python_list",
}
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
session_id = agent.create_session(f"test-session-{uuid4()}")

View file

@ -3,13 +3,22 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client import LlamaStackClient
@pytest.fixture
@pytest.fixture(scope="session")
def llama_stack_client():
"""Fixture to create a fresh LlamaStackClient instance for each test"""
return LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
if os.environ.get("LLAMA_STACK_CONFIG"):
client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
client.initialize()
elif os.environ.get("LLAMA_STACK_BASE_URL"):
client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
else:
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
return client

View file

@ -4,10 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
import traceback
import warnings
import pytest
from llama_stack_client.lib.inference.event_logger import EventLogger
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, "write") else sys.stderr
traceback.print_stack(file=log)
log.write(warnings.formatwarning(message, category, filename, lineno, line))
warnings.showwarning = warn_with_traceback
def test_text_chat_completion(llama_stack_client):
# non-streaming
available_models = [
@ -55,11 +68,15 @@ def test_image_chat_completion(llama_stack_client):
"role": "user",
"content": [
{
"image": {
"type": "image",
"data": {
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
}
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
"Describe what is in this image.",
],
}
response = llama_stack_client.inference.chat_completion(

View file

@ -22,6 +22,7 @@ def test_memory_bank(llama_stack_client):
llama_stack_client.memory_banks.register(
memory_bank_id=memory_bank_id,
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,