From 17abffb5052a8ca09a7b19fbbc843f732c480d41 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 6 Jan 2025 13:59:06 -0800 Subject: [PATCH] fix handle_docs --- .../agents/meta_reference/agent_instance.py | 39 +++++++------ tests/client-sdk/agents/test_agents.py | 57 +++++-------------- 2 files changed, 34 insertions(+), 62 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 43d5cbdb7..ac49a06ce 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -40,11 +40,7 @@ from llama_stack.apis.agents import ( ToolExecutionStep, Turn, ) -from llama_stack.apis.common.content_types import ( - InterleavedContent, - TextContentItem, - URL, -) +from llama_stack.apis.common.content_types import TextContentItem, URL from llama_stack.apis.inference import ( ChatCompletionResponseEventType, CompletionMessage, @@ -375,7 +371,10 @@ class ChatAgent(ShieldRunnerMixin): tool_args[tool.name] = tool.args tool_defs = await self._get_tool_defs(tools_for_turn) - await self.handle_documents(session_id, documents, input_messages, tool_defs) + if documents: + await self.handle_documents( + session_id, documents, input_messages, tool_defs + ) if "memory" in tool_defs and len(input_messages) > 0: with tracing.span("memory_tool") as span: step_id = str(uuid.uuid4()) @@ -759,26 +758,29 @@ class ChatAgent(ShieldRunnerMixin): input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], ) -> None: memory_tool = tool_defs.get("memory", None) - code_interpreter_tool = tool_defs.get("code_interpreter", None) + code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) if documents: - content_items = [ - d for d in documents if isinstance(d.content, InterleavedContent) - ] - url_items = [d for d in documents if isinstance(d.content, URL)] + 
content_items = [] + url_items = [] pattern = re.compile("^(https?://|file://|data:)") - url_items = [ - URL(uri=a.content) for a in url_items if pattern.match(a.content) - ] + for d in documents: + if isinstance(d.content, URL): + url_items.append(d.content) + elif pattern.match(d.content): + url_items.append(URL(uri=d.content)) + else: + content_items.append(d) + # Save the contents to a tempdir and use its path as a URL if code interpreter is present if code_interpreter_tool: for c in content_items: temp_file_path = os.path.join( self.tempdir, f"{make_random_string()}.txt" ) - with open(temp_file_path, "w") as temp_file: - temp_file.write(c.content) - url_items.append(URL(uri=f"file://{temp_file_path}")) + with open(temp_file_path, "w") as temp_file: + temp_file.write(c.content) + url_items.append(URL(uri=f"file://{temp_file_path}")) if memory_tool and code_interpreter_tool: # if both memory and code_interpreter are available, we download the URLs @@ -800,7 +803,7 @@ class ChatAgent(ShieldRunnerMixin): # if no memory or code_interpreter tool is available, # we try to load the data from the URLs and content items as a message to inference # and add it to the last message's context - input_messages[-1].context = content_items + load_data_from_urls( + input_messages[-1].context = content_items + await load_data_from_urls( url_items ) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 522a8a4eb..a8e06b7a2 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -14,6 +14,7 @@ from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.memory_insert_params import 
Document from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter @@ -208,7 +209,6 @@ def test_code_execution(llama_stack_client): model="meta-llama/Llama-3.1-70B-Instruct", instructions="You are a helpful assistant", tools=[ - "brave_search", "code_interpreter", ], tool_choice="required", @@ -217,49 +217,19 @@ def test_code_execution(llama_stack_client): enable_session_persistence=False, ) - memory_bank_id = "inflation_data_memory_bank" - llama_stack_client.memory_banks.register( - memory_bank_id=memory_bank_id, - params={ - "memory_bank_type": "vector", - "embedding_model": "all-MiniLM-L6-v2", - "chunk_size_in_tokens": 512, - "overlap_size_in_tokens": 64, - }, - ) - tool_name, _ = maybe_register_memory_tool(llama_stack_client) - agent_config["tools"].append(tool_name) codex_agent = Agent(llama_stack_client, agent_config) session_id = codex_agent.create_session("test-session") - - llama_stack_client.memory.insert( - bank_id=memory_bank_id, - documents=[ - Document( - document_id="inflation", - content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv", - mime_type="text/csv", - metadata={}, - ) - ], + inflation_doc = AgentDocument( + content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv", + mime_type="text/csv", ) - user_prompts = [ - { - "prompt": "Can you describe the data in the context?", - "tools": [{"name": "memory", "args": {"memory_bank_id": memory_bank_id}}], - }, - { - "prompt": "Plot average yearly inflation as a time series", - "tools": [ - {"name": "memory", "args": {"memory_bank_id": memory_bank_id}}, - "code_interpreter", - ], - }, + user_input = [ + {"prompt": "Here is a csv, can you describe it?", "documents": [inflation_doc]}, + {"prompt": "Plot average yearly inflation as a time series"}, ] - for input in user_prompts: - print(f'User> 
{input["prompt"]}') + for input in user_input: response = codex_agent.create_turn( messages=[ { @@ -268,13 +238,11 @@ } ], session_id=session_id, - tools=input["tools"], + documents=input.get("documents", None), ) - # for chunk in response: - # print(chunk) - - for log in EventLogger().log(response): - log.print() + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) def test_custom_tool(llama_stack_client, agent_config):