Fix LibraryClient completely correctly; also make tests pass

Ashwin Bharambe 2024-12-16 22:16:21 -08:00
parent d4935ca439
commit 1bcc26ccd1
7 changed files with 201 additions and 100 deletions

View file

@@ -13,10 +13,20 @@ import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar

+import httpx
 import yaml
-from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
+from llama_stack_client import (
+    APIResponse,
+    AsyncAPIResponse,
+    AsyncLlamaStackClient,
+    AsyncStream,
+    LlamaStackClient,
+    NOT_GIVEN,
+    Stream,
+)
 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
@@ -66,7 +76,7 @@ def stream_across_asyncio_run_boundary(
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
         try:
-            async for item in gen:
+            async for item in await gen:
                 result_queue.put(item)
         except Exception as e:
             print(f"Error in generator {e}")
@@ -112,31 +122,17 @@ def stream_across_asyncio_run_boundary(
             future.result()


-def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
+def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
     elif isinstance(value, list):
-        return [convert_pydantic_to_json_value(item, cast_to) for item in value]
+        return [convert_pydantic_to_json_value(item) for item in value]
     elif isinstance(value, dict):
-        return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
     elif isinstance(value, BaseModel):
-        # This is quite hacky and we should figure out how to use stuff from
-        # generated client-sdk code (using ApiResponse.parse() essentially)
-        value_dict = json.loads(value.model_dump_json())
-
-        origin = get_origin(cast_to)
-        if origin is Union:
-            args = get_args(cast_to)
-            for arg in args:
-                arg_name = arg.__name__.split(".")[-1]
-                value_name = value.__class__.__name__.split(".")[-1]
-                if arg_name == value_name:
-                    return arg(**value_dict)
-
-        # assume we have the correct association between the server-side type and the client-side type
-        return cast_to(**value_dict)
-
-    return value
+        return json.loads(value.model_dump_json())
+    else:
+        return value


 def convert_to_pydantic(annotation: Any, value: Any) -> Any:
@@ -278,16 +274,28 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")

-        params = options.params or {}
-        params |= options.json_data or {}
         if stream:
-            return self._call_streaming(options.url, params, cast_to)
+            return self._call_streaming(
+                cast_to=cast_to,
+                options=options,
+                stream_cls=stream_cls,
+            )
         else:
-            return await self._call_non_streaming(options.url, params, cast_to)
+            return await self._call_non_streaming(
+                cast_to=cast_to,
+                options=options,
+            )

     async def _call_non_streaming(
-        self, path: str, body: dict = None, cast_to: Any = None
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
     ):
+        path = options.url
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
@@ -295,11 +303,45 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")

             body = self._convert_body(path, body)
-            return convert_pydantic_to_json_value(await func(**body), cast_to)
+            result = await func(**body)
+
+            json_content = json.dumps(convert_pydantic_to_json_value(result))
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=json_content.encode("utf-8"),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+            response = APIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=False,
+                stream_cls=None,
+            )
+            return response.parse()
         finally:
             await end_trace()

-    async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
+    async def _call_streaming(
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
+        stream_cls: Any,
+    ):
+        path = options.url
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
@@ -307,8 +349,40 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")

             body = self._convert_body(path, body)
-            async for chunk in await func(**body):
-                yield convert_pydantic_to_json_value(chunk, cast_to)
+
+            async def gen():
+                async for chunk in await func(**body):
+                    data = json.dumps(convert_pydantic_to_json_value(chunk))
+                    sse_event = f"data: {data}\n\n"
+                    yield sse_event.encode("utf-8")
+
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=gen(),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+
+            origin = get_origin(stream_cls)
+            assert origin is Stream
+            args = get_args(stream_cls)
+            stream_cls = AsyncStream[args[0]]
+            response = AsyncAPIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=True,
+                stream_cls=stream_cls,
+            )
+            return await response.parse()
         finally:
             await end_trace()
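Note: the fix replaces the old ad-hoc casting (the cast_to plumbing removed from convert_pydantic_to_json_value) with a cleaner contract: serialize whatever the in-process endpoint returns to plain JSON, wrap it in a mock httpx.Response, and let the generated client's own APIResponse / AsyncAPIResponse machinery do exactly the parsing it would do for a real HTTP response. Streaming results are re-encoded as SSE frames (data: {...}\n\n) so the stock Stream / AsyncStream parsers work unchanged. A self-contained sketch of the serialization half (ServerResult is a made-up model for illustration):

import json
from enum import Enum
from typing import Any, List

from pydantic import BaseModel


def convert_pydantic_to_json_value(value: Any) -> Any:
    # mirrors the helper in this diff: unwrap enums, recurse into
    # containers, and dump pydantic models to plain dicts
    if isinstance(value, Enum):
        return value.value
    elif isinstance(value, list):
        return [convert_pydantic_to_json_value(v) for v in value]
    elif isinstance(value, dict):
        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
    elif isinstance(value, BaseModel):
        return json.loads(value.model_dump_json())
    else:
        return value


class ServerResult(BaseModel):  # hypothetical server-side type
    status: str
    scores: List[float]


payload = convert_pydantic_to_json_value({"result": ServerResult(status="ok", scores=[0.9])})
print(json.dumps(payload))  # {"result": {"status": "ok", "scores": [0.9]}}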

View file

@@ -6,6 +6,7 @@

 import logging
 import os
+import re
 from pathlib import Path
 from typing import Any, Dict

@@ -143,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
                 if default_val is None:
                     raise EnvVarError(env_var, path)
                 else:
-                    value = default_val
+                    value = default_val if default_val != "null" else None

             # expand "~" from the values
             return os.path.expanduser(value)
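With this change, an env reference whose default is the literal string "null" (e.g. ${env.FOO:null}) resolves to None when FOO is unset, instead of the string "null", so run configs can mark optional settings as absent. A standalone sketch of the substitution semantics (the regex and helper name here are illustrative, not the repo's exact code):

import os
import re

ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")


def resolve(template: str):
    # resolve a single "${env.NAME:default}" reference
    match = ENV_PATTERN.fullmatch(template)
    env_var, default_val = match.group(1), match.group(2)
    value = os.environ.get(env_var)
    if not value:
        if default_val is None:
            raise ValueError(f"{env_var} is not set and has no default")
        value = default_val if default_val != "null" else None
    return value


# assuming SOME_UNSET_VAR is not defined in the environment:
assert resolve("${env.SOME_UNSET_VAR:8080}") == "8080"
assert resolve("${env.SOME_UNSET_VAR:null}") is None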

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import json
 from contextlib import asynccontextmanager
 from typing import Dict, List, Optional, Protocol, Tuple

@@ -54,10 +53,7 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
     """Utility function to parse registry values into RoutableObjectWithProvider objects."""
     all_objects = []
     for value in values:
-        obj = pydantic.parse_obj_as(
-            RoutableObjectWithProvider,
-            json.loads(value),
-        )
+        obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
         all_objects.append(obj)

     return all_objects

@@ -89,14 +85,7 @@ class DiskDistributionRegistry(DistributionRegistry):
         if not json_str:
             return None

-        objects_data = json.loads(json_str)
-        # Return only the first object if any exist
-        if objects_data:
-            return pydantic.parse_obj_as(
-                RoutableObjectWithProvider,
-                json.loads(objects_data),
-            )
-        return None
+        return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)

     async def update(self, obj: RoutableObjectWithProvider) -> None:
         await self.kvstore.set(
View file

@@ -8,6 +8,7 @@ import json
 from typing import Dict, List
 from uuid import uuid4

+import pytest
 from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client.lib.agents.agent import Agent
@@ -77,16 +78,20 @@ class TestCustomTool(CustomTool):
         return -1


-def get_agent_config_with_available_models_shields(llama_stack_client):
+@pytest.fixture(scope="session")
+def agent_config(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
-        if model.identifier.startswith("meta-llama")
+        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     model_id = available_models[0]
+    print(f"Using model: {model_id}")
     available_shields = [
         shield.identifier for shield in llama_stack_client.shields.list()
     ]
+    available_shields = available_shields[:1]
+    print(f"Using shield: {available_shields}")
     agent_config = AgentConfig(
         model=model_id,
         instructions="You are a helpful assistant",
@@ -105,8 +110,7 @@ def get_agent_config_with_available_models_shields(llama_stack_client):
     return agent_config


-def test_agent_simple(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
+def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -142,16 +146,18 @@ def test_agent_simple(llama_stack_client):
     assert "I can't" in logs_str


-def test_builtin_tool_brave_search(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        }
-    ]
-    print(agent_config)
+def test_builtin_tool_brave_search(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            }
+        ],
+    }
+    print(f"Agent Config: {agent_config}")
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -174,13 +180,15 @@ def test_builtin_tool_brave_search(llama_stack_client):
     assert "No Violation" in logs_str


-def test_builtin_tool_code_execution(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "code_interpreter",
-        }
-    ]
+def test_builtin_tool_code_execution(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "code_interpreter",
+            }
+        ],
+    }
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -200,34 +208,36 @@ def test_builtin_tool_code_execution(llama_stack_client):
     assert "Tool:code_interpreter Response" in logs_str


-def test_custom_tool(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["model"] = "meta-llama/Llama-3.2-3B-Instruct"
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        },
-        {
-            "function_name": "get_boiling_point",
-            "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
-            "parameters": {
-                "liquid_name": {
-                    "param_type": "str",
-                    "description": "The name of the liquid",
-                    "required": True,
-                },
-                "celcius": {
-                    "param_type": "boolean",
-                    "description": "Whether to return the boiling point in Celcius",
-                    "required": False,
-                },
-            },
-            "type": "function_call",
-        },
-    ]
-    agent_config["tool_prompt_format"] = "python_list"
+def test_custom_tool(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "model": "meta-llama/Llama-3.2-3B-Instruct",
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            },
+            {
+                "function_name": "get_boiling_point",
+                "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
+                "parameters": {
+                    "liquid_name": {
+                        "param_type": "str",
+                        "description": "The name of the liquid",
+                        "required": True,
+                    },
+                    "celcius": {
+                        "param_type": "boolean",
+                        "description": "Whether to return the boiling point in Celcius",
+                        "required": False,
+                    },
+                },
+                "type": "function_call",
+            },
+        ],
+        "tool_prompt_format": "python_list",
+    }

     agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
     session_id = agent.create_session(f"test-session-{uuid4()}")
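The test refactor above swaps the get_agent_config_with_available_models_shields helper for a session-scoped fixture, and each test now copies the shared config with dict spreading instead of mutating it in place, so one test's tool setup cannot leak into the next. The pattern in isolation (names are illustrative):

import pytest


@pytest.fixture(scope="session")
def base_config():
    # computed once for the whole test session
    return {"model": "meta-llama/Llama-3.2-3B-Instruct", "tools": []}


def test_with_code_interpreter(base_config):
    # copy-and-override; the shared fixture value stays untouched
    config = {**base_config, "tools": [{"type": "code_interpreter"}]}
    assert config["tools"] and base_config["tools"] == []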

View file

@@ -3,13 +3,22 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import os

 import pytest
+from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient


-@pytest.fixture
+@pytest.fixture(scope="session")
 def llama_stack_client():
-    """Fixture to create a fresh LlamaStackClient instance for each test"""
-    return LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    if os.environ.get("LLAMA_STACK_CONFIG"):
+        client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
+        client.initialize()
+    elif os.environ.get("LLAMA_STACK_BASE_URL"):
+        client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    else:
+        raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+    return client
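With this fixture the same integration suite can run fully in-process (LLAMA_STACK_CONFIG pointing at a run config) or against a live server (LLAMA_STACK_BASE_URL), chosen purely by environment variable. A hypothetical direct use of the in-process branch, outside pytest:

import os

from llama_stack import LlamaStackAsLibraryClient

# assumes LLAMA_STACK_CONFIG points at a valid run configuration
client = LlamaStackAsLibraryClient(os.environ["LLAMA_STACK_CONFIG"])
client.initialize()
print([model.identifier for model in client.models.list()])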

View file

@@ -4,10 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
+import traceback
+import warnings
+
 import pytest
 from llama_stack_client.lib.inference.event_logger import EventLogger


+def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
+    log = file if hasattr(file, "write") else sys.stderr
+    traceback.print_stack(file=log)
+    log.write(warnings.formatwarning(message, category, filename, lineno, line))
+
+
+warnings.showwarning = warn_with_traceback
+
+
 def test_text_chat_completion(llama_stack_client):
     # non-streaming
     available_models = [
@@ -55,11 +68,15 @@ def test_image_chat_completion(llama_stack_client):
         "role": "user",
         "content": [
             {
-                "image": {
+                "type": "image",
+                "data": {
                     "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
-                }
+                },
+            },
+            {
+                "type": "text",
+                "text": "Describe what is in this image.",
             },
-            "Describe what is in this image.",
         ],
     }
     response = llama_stack_client.inference.chat_completion(
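The image test now sends typed content items ({"type": "image", ...} and {"type": "text", ...}) instead of a bare image dict followed by a raw string. The shape of the new message, with a placeholder URL:

message = {
    "role": "user",
    "content": [
        {
            "type": "image",
            "data": {"uri": "https://example.com/dog.jpg"},  # placeholder URL
        },
        {
            "type": "text",
            "text": "Describe what is in this image.",
        },
    ],
}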

View file

@@ -22,6 +22,7 @@ def test_memory_bank(llama_stack_client):
     llama_stack_client.memory_banks.register(
         memory_bank_id=memory_bank_id,
         params={
+            "memory_bank_type": "vector",
             "embedding_model": "all-MiniLM-L6-v2",
             "chunk_size_in_tokens": 512,
             "overlap_size_in_tokens": 64,