mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
client sdk test fixes
This commit is contained in:
parent
c3865faf37
commit
efe3189728
3 changed files with 21 additions and 25 deletions
|
@ -402,7 +402,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
if session_info.memory_bank_id:
|
if session_info.memory_bank_id:
|
||||||
query_args["memory_bank_id"] = session_info.memory_bank_id
|
query_args["memory_bank_id"] = session_info.memory_bank_id
|
||||||
serialized_args = tracing.serialize_value(query_args)
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
|
@ -412,8 +411,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
parse_status=ToolCallParseStatus.success,
|
parse_status=ToolCallParseStatus.success,
|
||||||
content=ToolCall(
|
content=ToolCall(
|
||||||
call_id="",
|
call_id="",
|
||||||
tool_name="memory",
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
arguments=serialized_args,
|
arguments={},
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -435,14 +434,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ToolCall(
|
ToolCall(
|
||||||
call_id="",
|
call_id="",
|
||||||
tool_name="memory",
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
arguments=serialized_args,
|
arguments={},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
tool_responses=[
|
tool_responses=[
|
||||||
ToolResponse(
|
ToolResponse(
|
||||||
call_id="",
|
call_id="",
|
||||||
tool_name="memory",
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
content=result.content,
|
content=result.content,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
@ -456,7 +455,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
span.set_attribute("output", result.content)
|
span.set_attribute("output", result.content)
|
||||||
span.set_attribute("error_code", result.error_code)
|
span.set_attribute("error_code", result.error_code)
|
||||||
span.set_attribute("error_message", result.error_message)
|
span.set_attribute("error_message", result.error_message)
|
||||||
span.set_attribute("tool_name", "memory")
|
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
||||||
if result.error_code == 0:
|
if result.error_code == 0:
|
||||||
last_message = input_messages[-1]
|
last_message = input_messages[-1]
|
||||||
last_message.context = result.content
|
last_message.context = result.content
|
||||||
|
|
|
@ -56,7 +56,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
return [
|
return [
|
||||||
ToolDef(
|
ToolDef(
|
||||||
name="memory",
|
name="query_memory",
|
||||||
description="Retrieve context from memory",
|
description="Retrieve context from memory",
|
||||||
parameters=[
|
parameters=[
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
|
|
|
@ -101,7 +101,7 @@ def agent_config(llama_stack_client):
|
||||||
"temperature": 1.0,
|
"temperature": 1.0,
|
||||||
"top_p": 0.9,
|
"top_p": 0.9,
|
||||||
},
|
},
|
||||||
tools=[],
|
toolgroups=[],
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
tool_prompt_format="json",
|
tool_prompt_format="json",
|
||||||
input_shields=available_shields,
|
input_shields=available_shields,
|
||||||
|
@ -152,8 +152,8 @@ def test_agent_simple(llama_stack_client, agent_config):
|
||||||
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"tools": [
|
"toolgroups": [
|
||||||
"builtin::web_search",
|
"builtin::websearch",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
agent = Agent(llama_stack_client, agent_config)
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
|
@ -181,7 +181,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
||||||
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"tools": [
|
"toolgroups": [
|
||||||
"builtin::code_interpreter",
|
"builtin::code_interpreter",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
@ -208,7 +208,7 @@ def test_code_execution(llama_stack_client):
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
model="meta-llama/Llama-3.1-70B-Instruct",
|
model="meta-llama/Llama-3.1-70B-Instruct",
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
tools=[
|
toolgroups=[
|
||||||
"builtin::code_interpreter",
|
"builtin::code_interpreter",
|
||||||
],
|
],
|
||||||
tool_choice="required",
|
tool_choice="required",
|
||||||
|
@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
"tools": ["builtin::web_search"],
|
"toolgroups": ["builtin::websearch"],
|
||||||
"client_tools": [client_tool.get_tool_definition()],
|
"client_tools": [client_tool.get_tool_definition()],
|
||||||
"tool_prompt_format": "python_list",
|
"tool_prompt_format": "python_list",
|
||||||
}
|
}
|
||||||
|
@ -293,9 +293,14 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
for i, url in enumerate(urls)
|
for i, url in enumerate(urls)
|
||||||
]
|
]
|
||||||
|
|
||||||
agent_config["tools"].append("builtin::memory")
|
|
||||||
agent = Agent(llama_stack_client, agent_config)
|
|
||||||
memory_bank_id = "test-memory-bank"
|
memory_bank_id = "test-memory-bank"
|
||||||
|
agent_config["toolgroups"].append(
|
||||||
|
dict(
|
||||||
|
name="builtin::memory",
|
||||||
|
args={"memory_bank_id": memory_bank_id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
llama_stack_client.memory_banks.register(
|
llama_stack_client.memory_banks.register(
|
||||||
memory_bank_id=memory_bank_id,
|
memory_bank_id=memory_bank_id,
|
||||||
params={
|
params={
|
||||||
|
@ -326,16 +331,8 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"name": "memory",
|
|
||||||
"args": {
|
|
||||||
"memory_bank_id": memory_bank_id,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
assert "Tool:memory" in logs_str
|
assert "Tool:query_memory" in logs_str
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue