mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:19:49 +00:00
client sdk test fixes
This commit is contained in:
parent
c3865faf37
commit
efe3189728
3 changed files with 21 additions and 25 deletions
|
|
@ -101,7 +101,7 @@ def agent_config(llama_stack_client):
|
|||
"temperature": 1.0,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
tools=[],
|
||||
toolgroups=[],
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="json",
|
||||
input_shields=available_shields,
|
||||
|
|
@ -152,8 +152,8 @@ def test_agent_simple(llama_stack_client, agent_config):
|
|||
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"tools": [
|
||||
"builtin::web_search",
|
||||
"toolgroups": [
|
||||
"builtin::websearch",
|
||||
],
|
||||
}
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
|
|
@ -181,7 +181,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
|||
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"tools": [
|
||||
"toolgroups": [
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
}
|
||||
|
|
@ -208,7 +208,7 @@ def test_code_execution(llama_stack_client):
|
|||
agent_config = AgentConfig(
|
||||
model="meta-llama/Llama-3.1-70B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
tools=[
|
||||
toolgroups=[
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
tool_choice="required",
|
||||
|
|
@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
agent_config = {
|
||||
**agent_config,
|
||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"tools": ["builtin::web_search"],
|
||||
"toolgroups": ["builtin::websearch"],
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
"tool_prompt_format": "python_list",
|
||||
}
|
||||
|
|
@ -293,9 +293,14 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
agent_config["tools"].append("builtin::memory")
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
memory_bank_id = "test-memory-bank"
|
||||
agent_config["toolgroups"].append(
|
||||
dict(
|
||||
name="builtin::memory",
|
||||
args={"memory_bank_id": memory_bank_id},
|
||||
)
|
||||
)
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
llama_stack_client.memory_banks.register(
|
||||
memory_bank_id=memory_bank_id,
|
||||
params={
|
||||
|
|
@ -326,16 +331,8 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
tools=[
|
||||
{
|
||||
"name": "memory",
|
||||
"args": {
|
||||
"memory_bank_id": memory_bank_id,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "Tool:memory" in logs_str
|
||||
assert "Tool:query_memory" in logs_str
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue