client sdk test fixes

This commit is contained in:
Dinesh Yeduguru 2025-01-07 16:59:01 -08:00
parent c3865faf37
commit efe3189728
3 changed files with 21 additions and 25 deletions

View file

@ -402,7 +402,6 @@ class ChatAgent(ShieldRunnerMixin):
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
query_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(query_args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -412,8 +411,8 @@ class ChatAgent(ShieldRunnerMixin):
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="",
tool_name="memory",
arguments=serialized_args,
tool_name=MEMORY_QUERY_TOOL,
arguments={},
),
),
)
@ -435,14 +434,14 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls=[
ToolCall(
call_id="",
tool_name="memory",
arguments=serialized_args,
tool_name=MEMORY_QUERY_TOOL,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name="memory",
tool_name=MEMORY_QUERY_TOOL,
content=result.content,
)
],
@ -456,7 +455,7 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", "memory")
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content

View file

@ -56,7 +56,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
) -> List[ToolDef]:
return [
ToolDef(
name="memory",
name="query_memory",
description="Retrieve context from memory",
parameters=[
ToolParameter(

View file

@ -101,7 +101,7 @@ def agent_config(llama_stack_client):
"temperature": 1.0,
"top_p": 0.9,
},
tools=[],
toolgroups=[],
tool_choice="auto",
tool_prompt_format="json",
input_shields=available_shields,
@ -152,8 +152,8 @@ def test_agent_simple(llama_stack_client, agent_config):
def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
"builtin::web_search",
"toolgroups": [
"builtin::websearch",
],
}
agent = Agent(llama_stack_client, agent_config)
@ -181,7 +181,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
"toolgroups": [
"builtin::code_interpreter",
],
}
@ -208,7 +208,7 @@ def test_code_execution(llama_stack_client):
agent_config = AgentConfig(
model="meta-llama/Llama-3.1-70B-Instruct",
instructions="You are a helpful assistant",
tools=[
toolgroups=[
"builtin::code_interpreter",
],
tool_choice="required",
@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
"tools": ["builtin::web_search"],
"toolgroups": ["builtin::websearch"],
"client_tools": [client_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}
@ -293,9 +293,14 @@ def test_rag_agent(llama_stack_client, agent_config):
for i, url in enumerate(urls)
]
agent_config["tools"].append("builtin::memory")
agent = Agent(llama_stack_client, agent_config)
memory_bank_id = "test-memory-bank"
agent_config["toolgroups"].append(
dict(
name="builtin::memory",
args={"memory_bank_id": memory_bank_id},
)
)
agent = Agent(llama_stack_client, agent_config)
llama_stack_client.memory_banks.register(
memory_bank_id=memory_bank_id,
params={
@ -326,16 +331,8 @@ def test_rag_agent(llama_stack_client, agent_config):
}
],
session_id=session_id,
tools=[
{
"name": "memory",
"args": {
"memory_bank_id": memory_bank_id,
},
}
],
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:memory" in logs_str
assert "Tool:query_memory" in logs_str