From db2ec110a1a85c7cadc644d44e71873617f0bea9 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 22:13:33 -0800 Subject: [PATCH] fix failing code interpreter tests --- .../agents/meta_reference/agent_instance.py | 3 +- tests/client-sdk/agents/test_agents.py | 48 +++++++------------ 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f9ffb2ae0..0c1d50b15 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -78,7 +78,6 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") MEMORY_TOOL_GROUP_ID = "builtin::memory" MEMORY_QUERY_TOOL = "query_memory" -CODE_INTERPRETER_TOOL = "code_interpreter" WEB_SEARCH_TOOL = "web_search" @@ -787,7 +786,7 @@ class ChatAgent(ShieldRunnerMixin): tool_defs: Dict[str, ToolDefinition], ) -> None: memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None) - code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None) + code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 01ffe2025..a2ed687a4 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -275,14 +275,7 @@ def test_custom_tool(llama_stack_client, agent_config): def test_rag_agent(llama_stack_client, agent_config): - urls = [ - "memory_optimizations.rst", - "chat.rst", - "llama3.rst", - "datasets.rst", - "qat_finetune.rst", - "lora_finetune.rst", - ] + urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"] documents = [ Document( document_id=f"num-{i}", @@ -292,15 +285,7 @@ def test_rag_agent(llama_stack_client, agent_config): ) for i, url in enumerate(urls) ] - memory_bank_id = "test-memory-bank" - agent_config["toolgroups"].append( - dict( - name="builtin::memory", - args={"memory_bank_ids": [memory_bank_id]}, - ) - ) - agent = Agent(llama_stack_client, agent_config) llama_stack_client.memory_banks.register( memory_bank_id=memory_bank_id, params={ @@ -314,25 +299,28 @@ def test_rag_agent(llama_stack_client, agent_config): bank_id=memory_bank_id, documents=documents, ) - session_id = agent.create_session(f"test-session-{uuid4()}") - + agent_config = { + **agent_config, + "toolgroups": [ + dict( + name="builtin::memory", + args={ + "memory_bank_ids": [memory_bank_id], + }, + ) + ], + } + rag_agent = Agent(llama_stack_client, agent_config) + session_id = rag_agent.create_session("test-session") user_prompts = [ - "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.", - "Was anything related to 'Llama3' discussed, if so what?", - "Tell me how to use LoRA", + "What are the top 5 topics that were explained? Only list succinct bullet points.", ] - for prompt in user_prompts: - response = agent.create_turn( - messages=[ - { - "role": "user", - "content": prompt, - } - ], + print(f"User> {prompt}") + response = rag_agent.create_turn( + messages=[{"role": "user", "content": prompt}], session_id=session_id, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "Tool:query_memory" in logs_str