Fix LibraryClient completely correctly; also make tests pass
commit 1bcc26ccd1 (parent d4935ca439)

7 changed files with 201 additions and 100 deletions
@@ -13,10 +13,20 @@ import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
 
+import httpx
+
 import yaml
-from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
+from llama_stack_client import (
+    APIResponse,
+    AsyncAPIResponse,
+    AsyncLlamaStackClient,
+    AsyncStream,
+    LlamaStackClient,
+    NOT_GIVEN,
+    Stream,
+)
 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
 
@@ -66,7 +76,7 @@ def stream_across_asyncio_run_boundary(
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
         try:
-            async for item in gen:
+            async for item in await gen:
                 result_queue.put(item)
         except Exception as e:
             print(f"Error in generator {e}")
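
Note: `stream_across_asyncio_run_boundary` exists to drain an async generator on a background event loop and re-expose it as a synchronous iterator. A minimal sketch of that pattern under assumed names (the real helper also propagates errors and handles the extra `await` seen above):

    import asyncio
    import queue
    import threading

    def sync_iter_from_async_gen(async_gen_maker):
        # Drain the async generator on a dedicated event-loop thread,
        # handing items to the caller through a thread-safe queue.
        result_queue = queue.Queue()
        _DONE = object()  # sentinel marking end of stream

        def run_loop():
            async def drain():
                async for item in async_gen_maker():
                    result_queue.put(item)
            asyncio.run(drain())
            result_queue.put(_DONE)

        threading.Thread(target=run_loop, daemon=True).start()
        while (item := result_queue.get()) is not _DONE:
            yield item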
@@ -112,31 +122,17 @@ def stream_across_asyncio_run_boundary(
     future.result()
 
 
-def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
+def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
     elif isinstance(value, list):
-        return [convert_pydantic_to_json_value(item, cast_to) for item in value]
+        return [convert_pydantic_to_json_value(item) for item in value]
     elif isinstance(value, dict):
-        return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
     elif isinstance(value, BaseModel):
-        # This is quite hacky and we should figure out how to use stuff from
-        # generated client-sdk code (using ApiResponse.parse() essentially)
-        value_dict = json.loads(value.model_dump_json())
-
-        origin = get_origin(cast_to)
-        if origin is Union:
-            args = get_args(cast_to)
-            for arg in args:
-                arg_name = arg.__name__.split(".")[-1]
-                value_name = value.__class__.__name__.split(".")[-1]
-                if arg_name == value_name:
-                    return arg(**value_dict)
-
-        # assume we have the correct association between the server-side type and the client-side type
-        return cast_to(**value_dict)
-
-        return value
+        return json.loads(value.model_dump_json())
     else:
         return value
 
 
 def convert_to_pydantic(annotation: Any, value: Any) -> Any:
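
Note: with `cast_to` removed, `convert_pydantic_to_json_value` is a plain recursive serializer; re-typing into client-side models is delegated to the SDK's response parsing (see the hunks below). Assuming the rewritten function above is in scope, it behaves like this (illustrative models only):

    from enum import Enum
    from pydantic import BaseModel

    class Color(Enum):
        RED = "red"

    class Point(BaseModel):
        x: int
        color: Color

    # Enums collapse to their values, BaseModels to plain dicts; lists and dicts recurse.
    payload = convert_pydantic_to_json_value({"pts": [Point(x=1, color=Color.RED)]})
    # -> {"pts": [{"x": 1, "color": "red"}]}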
@@ -278,16 +274,28 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
-        params = options.params or {}
-        params |= options.json_data or {}
         if stream:
-            return self._call_streaming(options.url, params, cast_to)
+            return self._call_streaming(
+                cast_to=cast_to,
+                options=options,
+                stream_cls=stream_cls,
+            )
         else:
-            return await self._call_non_streaming(options.url, params, cast_to)
+            return await self._call_non_streaming(
+                cast_to=cast_to,
+                options=options,
+            )
 
     async def _call_non_streaming(
-        self, path: str, body: dict = None, cast_to: Any = None
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
     ):
+        path = options.url
+
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
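
Note: both `_call_streaming` and `_call_non_streaming` now take keyword-only arguments (the bare `*` in the signature), so call sites must spell out `cast_to` and `options` and cannot mix them up positionally. A tiny standalone illustration:

    def call(*, cast_to, options):
        # `*` makes every following parameter keyword-only.
        return cast_to, options

    call(cast_to=str, options={"url": "/models"})   # OK
    # call(str, {"url": "/models"})                 # TypeError: takes 0 positional arguments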
@@ -295,11 +303,45 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            return convert_pydantic_to_json_value(await func(**body), cast_to)
+            result = await func(**body)
+
+            json_content = json.dumps(convert_pydantic_to_json_value(result))
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=json_content.encode("utf-8"),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+            response = APIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=False,
+                stream_cls=None,
+            )
+            return response.parse()
         finally:
             await end_trace()
 
-    async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
+    async def _call_streaming(
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
+        stream_cls: Any,
+    ):
+        path = options.url
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
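
Note: instead of hand-rolling type conversion, the non-streaming path now wraps the locally computed result in a synthetic `httpx.Response` and lets the generated SDK's `APIResponse.parse()` recover the typed object. The trick in isolation, with a stand-in model rather than the SDK's parser (a hedged sketch):

    import json

    import httpx
    from pydantic import BaseModel

    class CompletionResult(BaseModel):
        text: str

    # Serialize a locally computed result as if a server had sent it...
    mock_response = httpx.Response(
        status_code=httpx.codes.OK,
        content=json.dumps({"text": "hello"}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        request=httpx.Request(method="POST", url="http://localhost/v1/completion"),
    )
    # ...then parse it back with the same machinery a real HTTP response would use.
    parsed = CompletionResult.model_validate_json(mock_response.content)
    print(parsed.text)  # hello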
@@ -307,8 +349,40 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            async for chunk in await func(**body):
-                yield convert_pydantic_to_json_value(chunk, cast_to)
+
+            async def gen():
+                async for chunk in await func(**body):
+                    data = json.dumps(convert_pydantic_to_json_value(chunk))
+                    sse_event = f"data: {data}\n\n"
+                    yield sse_event.encode("utf-8")
+
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=gen(),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+            origin = get_origin(stream_cls)
+            assert origin is Stream
+            args = get_args(stream_cls)
+            stream_cls = AsyncStream[args[0]]
+            response = AsyncAPIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=True,
+                stream_cls=stream_cls,
+            )
+            return await response.parse()
         finally:
             await end_trace()
 
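
Note: on the streaming side, each chunk is re-encoded as a server-sent event so the SDK's `Stream`/`AsyncStream` machinery can consume the synthetic response exactly as it would a live SSE connection (the `get_origin`/`get_args` dance rebinds the sync `Stream[T]` the caller requested to `AsyncStream[T]`). The SSE framing in miniature:

    import json

    def to_sse_event(chunk: dict) -> bytes:
        # One SSE frame: a "data:" line carrying the JSON payload, then a blank line.
        return f"data: {json.dumps(chunk)}\n\n".encode("utf-8")

    print(to_sse_event({"delta": "Hel"}))  # b'data: {"delta": "Hel"}\n\n'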
@@ -6,6 +6,7 @@
 
 import logging
 import os
+import re
 from pathlib import Path
 from typing import Any, Dict
 
@@ -143,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
             if default_val is None:
                 raise EnvVarError(env_var, path)
             else:
-                value = default_val
+                value = default_val if default_val != "null" else None
 
         # expand "~" from the values
         return os.path.expanduser(value)
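
Note: the one-line change above means a template default of the literal string "null" now resolves to None rather than the string "null", so configs can mark values such as API keys as genuinely optional via ${env.API_KEY:null}. A toy version of the substitution rule (hypothetical regex, not the module's actual one):

    import os
    import re

    def resolve(template: str):
        # Matches ${env.VAR:default}; the ":default" part is optional.
        m = re.fullmatch(r"\$\{env\.(\w+)(?::(.*))?\}", template)
        var, default_val = m.group(1), m.group(2)
        value = os.environ.get(var)
        if value is None:
            # A default of "null" now means "no value" instead of the string "null".
            value = default_val if default_val != "null" else None
        return value

    os.environ.pop("API_KEY", None)
    print(resolve("${env.API_KEY:null}"))  # None
    print(resolve("${env.API_KEY:abc}"))   # abc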
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import json
 from contextlib import asynccontextmanager
 from typing import Dict, List, Optional, Protocol, Tuple
 
@@ -54,10 +53,7 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
     """Utility function to parse registry values into RoutableObjectWithProvider objects."""
     all_objects = []
     for value in values:
-        obj = pydantic.parse_obj_as(
-            RoutableObjectWithProvider,
-            json.loads(value),
-        )
+        obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
         all_objects.append(obj)
     return all_objects
 
@@ -89,14 +85,7 @@ class DiskDistributionRegistry(DistributionRegistry):
         if not json_str:
             return None
 
-        objects_data = json.loads(json_str)
-        # Return only the first object if any exist
-        if objects_data:
-            return pydantic.parse_obj_as(
-                RoutableObjectWithProvider,
-                json.loads(objects_data),
-            )
-        return None
+        return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
 
     async def update(self, obj: RoutableObjectWithProvider) -> None:
         await self.kvstore.set(
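
Note: `pydantic.parse_obj_as` is deprecated in Pydantic v2, and the old path also double-decoded the payload (`json.loads` before handing it to pydantic). `TypeAdapter(...).validate_json(...)` parses and validates the raw JSON string in one step. For example, with illustrative types standing in for the registry's union:

    from typing import Literal, Union

    import pydantic

    class ModelObj(pydantic.BaseModel):
        type: Literal["model"]
        identifier: str

    class ShieldObj(pydantic.BaseModel):
        type: Literal["shield"]
        identifier: str

    Routable = Union[ModelObj, ShieldObj]

    # validate_json parses the string and discriminates the union in one step.
    obj = pydantic.TypeAdapter(Routable).validate_json('{"type": "shield", "identifier": "s1"}')
    print(type(obj).__name__)  # ShieldObj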
@@ -8,6 +8,7 @@ import json
 from typing import Dict, List
 from uuid import uuid4
 
 import pytest
+from llama_stack.providers.tests.env import get_env_or_fail
 
 from llama_stack_client.lib.agents.agent import Agent
@@ -77,16 +78,20 @@ class TestCustomTool(CustomTool):
         return -1
 
 
-def get_agent_config_with_available_models_shields(llama_stack_client):
+@pytest.fixture(scope="session")
+def agent_config(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
-        if model.identifier.startswith("meta-llama")
+        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     model_id = available_models[0]
     print(f"Using model: {model_id}")
     available_shields = [
         shield.identifier for shield in llama_stack_client.shields.list()
     ]
     available_shields = available_shields[:1]
     print(f"Using shield: {available_shields}")
     agent_config = AgentConfig(
         model=model_id,
         instructions="You are a helpful assistant",
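
Note: turning the config helper into a `@pytest.fixture(scope="session")` means model/shield discovery runs once per test session and every test that asks for `agent_config` shares the result. The mechanics, with a hypothetical fixture body:

    import pytest

    @pytest.fixture(scope="session")
    def base_config():
        # Evaluated once per session, then cached by pytest for all requesting tests.
        print("discovering models...")
        return {"model": "some-model", "instructions": "You are a helpful assistant"}

    def test_one(base_config):
        assert base_config["model"] == "some-model"

    def test_two(base_config):
        # Same cached object as in test_one; the fixture body did not rerun.
        assert "instructions" in base_config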
@@ -105,8 +110,7 @@ def get_agent_config_with_available_models_shields(llama_stack_client):
     return agent_config
 
 
-def test_agent_simple(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
+def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
@@ -142,16 +146,18 @@ def test_agent_simple(llama_stack_client):
     assert "I can't" in logs_str
 
 
-def test_builtin_tool_brave_search(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        }
-    ]
-    print(agent_config)
+def test_builtin_tool_brave_search(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            }
+        ],
+    }
+    print(f"Agent Config: {agent_config}")
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
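
Note: because the fixture's dict is now shared across tests, each test derives its own copy with dict-spreading instead of mutating the shared object in place, so tool settings cannot leak between tests. The pattern in miniature:

    base = {"model": "m", "instructions": "hi"}

    # A shallow copy with one extra key; `base` itself is untouched.
    with_tools = {
        **base,
        "tools": [{"type": "code_interpreter"}],
    }

    assert "tools" not in base
    assert with_tools["model"] == "m"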
@@ -174,13 +180,15 @@ def test_builtin_tool_brave_search(llama_stack_client):
     assert "No Violation" in logs_str
 
 
-def test_builtin_tool_code_execution(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "code_interpreter",
-        }
-    ]
+def test_builtin_tool_code_execution(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "code_interpreter",
+            }
+        ],
+    }
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
@@ -200,34 +208,36 @@ def test_builtin_tool_code_execution(llama_stack_client):
     assert "Tool:code_interpreter Response" in logs_str
 
 
-def test_custom_tool(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["model"] = "meta-llama/Llama-3.2-3B-Instruct"
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        },
-        {
-            "function_name": "get_boiling_point",
-            "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
-            "parameters": {
-                "liquid_name": {
-                    "param_type": "str",
-                    "description": "The name of the liquid",
-                    "required": True,
-                },
-                "celcius": {
-                    "param_type": "boolean",
-                    "description": "Whether to return the boiling point in Celcius",
-                    "required": False,
-                },
-            },
-            "type": "function_call",
-        },
-    ]
-    agent_config["tool_prompt_format"] = "python_list"
+def test_custom_tool(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "model": "meta-llama/Llama-3.2-3B-Instruct",
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            },
+            {
+                "function_name": "get_boiling_point",
+                "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
+                "parameters": {
+                    "liquid_name": {
+                        "param_type": "str",
+                        "description": "The name of the liquid",
+                        "required": True,
+                    },
+                    "celcius": {
+                        "param_type": "boolean",
+                        "description": "Whether to return the boiling point in Celcius",
+                        "required": False,
+                    },
+                },
+                "type": "function_call",
+            },
+        ],
+        "tool_prompt_format": "python_list",
+    }
 
     agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -3,13 +3,22 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import os
 
 import pytest
+from llama_stack import LlamaStackAsLibraryClient
 
 from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def llama_stack_client():
     """Fixture to create a fresh LlamaStackClient instance for each test"""
-    return LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    if os.environ.get("LLAMA_STACK_CONFIG"):
+        client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
+        client.initialize()
+    elif os.environ.get("LLAMA_STACK_BASE_URL"):
+        client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    else:
+        raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+    return client
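
Note: with this fixture, the same client-sdk tests can run either fully in-process through `LlamaStackAsLibraryClient` or against a live server, selected purely by environment variable. Illustrative invocations (paths are examples, not prescribed):

    # Run the integration tests in-process, no server needed:
    #   LLAMA_STACK_CONFIG=./run.yaml pytest tests/client-sdk
    #
    # Or point the same tests at a running stack:
    #   LLAMA_STACK_BASE_URL=http://localhost:5000 pytest tests/client-sdk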
@@ -4,10 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import sys
+import traceback
+import warnings
+
 import pytest
 from llama_stack_client.lib.inference.event_logger import EventLogger
 
 
+def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
+    log = file if hasattr(file, "write") else sys.stderr
+    traceback.print_stack(file=log)
+    log.write(warnings.formatwarning(message, category, filename, lineno, line))
+
+
+warnings.showwarning = warn_with_traceback
+
+
 def test_text_chat_completion(llama_stack_client):
     # non-streaming
     available_models = [
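
Note: overriding `warnings.showwarning` makes every warning print the stack that raised it, which helps pin down warnings (e.g. "coroutine ... was never awaited") emitted deep inside async machinery. The hook is self-contained and can be tried in isolation:

    import sys
    import traceback
    import warnings

    def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
        log = file if hasattr(file, "write") else sys.stderr
        traceback.print_stack(file=log)
        log.write(warnings.formatwarning(message, category, filename, lineno, line))

    warnings.showwarning = warn_with_traceback

    warnings.warn("something looks off", RuntimeWarning)  # now prints a stack trace too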
@@ -55,11 +68,15 @@ def test_image_chat_completion(llama_stack_client):
         "role": "user",
         "content": [
             {
-                "image": {
+                "type": "image",
+                "data": {
                     "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
-                }
+                },
             },
-            "Describe what is in this image.",
+            {
+                "type": "text",
+                "text": "Describe what is in this image.",
+            },
         ],
     }
     response = llama_stack_client.inference.chat_completion(
@@ -22,6 +22,7 @@ def test_memory_bank(llama_stack_client):
     llama_stack_client.memory_banks.register(
         memory_bank_id=memory_bank_id,
+        params={
             "memory_bank_type": "vector",
             "embedding_model": "all-MiniLM-L6-v2",
             "chunk_size_in_tokens": 512,
             "overlap_size_in_tokens": 64,