diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index eca7364d7..706dd74f1 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -66,6 +66,7 @@ from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
+
 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
 
@@ -476,9 +477,12 @@ class ChatAgent(ShieldRunnerMixin):
                 )
                 span.set_attribute("output", retrieved_context)
                 span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-                if retrieved_context:
-                    last_message = input_messages[-1]
-                    last_message.context = retrieved_context
+
+                # append retrieved_context to the last user message
+                for message in input_messages[::-1]:
+                    if isinstance(message, UserMessage):
+                        message.context = retrieved_context
+                        break
 
         output_attachments = []
 
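Note on the hunk above: in a multi-step turn the last element of input_messages is not guaranteed to be a UserMessage (it can be a tool response), and the context attribute lives on user messages. Below is a minimal standalone sketch of the new reverse-scan lookup; the message classes are illustrative stand-ins, not the real llama_stack types.

from dataclasses import dataclass
from typing import Optional


@dataclass
class UserMessage:  # stand-in for llama_stack's UserMessage
    content: str
    context: Optional[str] = None


@dataclass
class ToolResponseMessage:  # stand-in; carries no context field
    content: str


retrieved_context = "chunk 1\nchunk 2"
input_messages = [
    UserMessage("What does the document say about attention?"),
    ToolResponseMessage("...tool output from a previous step..."),
]

# input_messages[-1] is a tool response here, so the old
# input_messages[-1].context assignment would target the wrong message;
# scanning in reverse finds the most recent user message instead.
for message in input_messages[::-1]:
    if isinstance(message, UserMessage):
        message.context = retrieved_context
        break

assert input_messages[0].context == retrieved_context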
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 4a8fdd36a..e0f86e3d7 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -211,7 +211,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
     }
 
     codex_agent = Agent(llama_stack_client, agent_config)
-    session_id = codex_agent.create_session("test-session")
+    session_id = codex_agent.create_session(f"test-session-{uuid4()}")
     inflation_doc = AgentDocument(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
         mime_type="text/csv",
@@ -285,7 +285,8 @@ def test_rag_agent(llama_stack_client, agent_config):
     llama_stack_client.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
-        chunk_size_in_tokens=512,
+        # small chunks help to get specific info out of the docs
+        chunk_size_in_tokens=128,
     )
     agent_config = {
         **agent_config,
@@ -299,11 +300,15 @@
         ],
     }
     rag_agent = Agent(llama_stack_client, agent_config)
-    session_id = rag_agent.create_session("test-session")
+    session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
-        "What are the top 5 topics that were explained? Only list succinct bullet points.",
+        (
+            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
+            "grouped-query",
+        ),
+        ("What command to use to get access to Llama3-8B-Instruct?", "tune download"),
     ]
-    for prompt in user_prompts:
+    for prompt, expected_kw in user_prompts:
         print(f"User> {prompt}")
         response = rag_agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
@@ -312,3 +317,69 @@
         logs = [str(log) for log in EventLogger().log(response) if log is not None]
         logs_str = "".join(logs)
         assert "Tool:query_from_memory" in logs_str
+        assert expected_kw in logs_str.lower()
+
+
+def test_rag_and_code_agent(llama_stack_client, agent_config):
+    urls = ["chat.rst"]
+    documents = [
+        Document(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={},
+        )
+        for i, url in enumerate(urls)
+    ]
+    vector_db_id = "test-vector-db"
+    llama_stack_client.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model="all-MiniLM-L6-v2",
+        embedding_dimension=384,
+    )
+    llama_stack_client.tool_runtime.rag_tool.insert(
+        documents=documents,
+        vector_db_id=vector_db_id,
+        chunk_size_in_tokens=128,
+    )
+    agent_config = {
+        **agent_config,
+        "toolgroups": [
+            dict(
+                name="builtin::rag",
+                args={"vector_db_ids": [vector_db_id]},
+            ),
+            "builtin::code_interpreter",
+        ],
+    }
+    agent = Agent(llama_stack_client, agent_config)
+    inflation_doc = Document(
+        document_id="test_csv",
+        content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
+        mime_type="text/csv",
+        metadata={},
+    )
+    user_prompts = [
+        (
+            "Here is a csv file, can you describe it?",
+            [inflation_doc],
+            "code_interpreter",
+        ),
+        (
+            "What are the top 5 topics that were explained? Only list succinct bullet points.",
+            [],
+            "query_from_memory",
+        ),
+    ]
+
+    for prompt, docs, tool_name in user_prompts:
+        print(f"User> {prompt}")
+        session_id = agent.create_session(f"test-session-{uuid4()}")
+        response = agent.create_turn(
+            messages=[{"role": "user", "content": prompt}],
+            session_id=session_id,
+            documents=docs,
+        )
+        logs = [str(log) for log in EventLogger().log(response) if log is not None]
+        logs_str = "".join(logs)
+        assert f"Tool:{tool_name}" in logs_str
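Note on the test changes above: session names now carry a uuid4 suffix so repeated runs against a long-lived stack don't collide on a shared "test-session", and each prompt is paired with a keyword asserted case-insensitively against the event log. A minimal sketch of that pattern; the logs_str value here is illustrative, not real server output:

from uuid import uuid4

# Unique session names avoid collisions when tests are re-run against the
# same persistent stack instance.
session_name = f"test-session-{uuid4()}"

# Case-insensitive keyword check, as in the updated test_rag_agent:
user_prompts = [
    (
        "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
        "grouped-query",
    ),
]
logs_str = "Tool:query_from_memory ... Llama3-8B uses Grouped-Query Attention."  # illustrative
for _prompt, expected_kw in user_prompts:
    assert expected_kw in logs_str.lower()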
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 6dff1be24..b10ede357 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import base64
-import os
+import pathlib
 
 import pytest
 from pydantic import BaseModel
@@ -57,13 +57,20 @@ def get_weather_tool_definition():
 
 
 @pytest.fixture
-def base64_image_url():
-    image_path = os.path.join(os.path.dirname(__file__), "dog.png")
-    with open(image_path, "rb") as image_file:
-        # Convert the image to base64
-        base64_string = base64.b64encode(image_file.read()).decode("utf-8")
-        base64_url = f"data:image/png;base64,{base64_string}"
-        return base64_url
+def image_path():
+    return pathlib.Path(__file__).parent / "dog.png"
+
+
+@pytest.fixture
+def base64_image_data(image_path):
+    # Convert the image to base64
+    return base64.b64encode(image_path.read_bytes()).decode("utf-8")
+
+
+@pytest.fixture
+def base64_image_url(base64_image_data, image_path):
+    # suffix includes the ., so we remove it
+    return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
 
 
 def test_text_completion_non_streaming(llama_stack_client, text_model_id):
@@ -371,20 +378,31 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
 
 
-def test_image_chat_completion_base64_url(
-    llama_stack_client, vision_model_id, base64_image_url
+@pytest.mark.parametrize("type_", ["url", "data"])
+def test_image_chat_completion_base64(
+    llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_
 ):
+    image_spec = {
+        "url": {
+            "type": "image",
+            "image": {
+                "url": {
+                    "uri": base64_image_url,
+                },
+            },
+        },
+        "data": {
+            "type": "image",
+            "image": {
+                "data": base64_image_data,
+            },
+        },
+    }[type_]
+
     message = {
        "role": "user",
        "content": [
-            {
-                "type": "image",
-                "image": {
-                    "url": {
-                        "uri": base64_image_url,
-                    },
-                },
-            },
+            image_spec,
             {
                 "type": "text",
                 "text": "Describe what is in this image.",
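The final hunk above parametrizes the image test over the two base64 encodings the message schema accepts: a data: URL and raw base64 data. As a standalone sketch, the three fixtures compose to roughly the following; the dog.png path is taken from the test and assumes the image sits next to the script:

import base64
import pathlib


def to_data_url(image_path: pathlib.Path) -> str:
    # suffix includes the leading ".", so strip it to get the media subtype
    base64_data = base64.b64encode(image_path.read_bytes()).decode("utf-8")
    return f"data:image/{image_path.suffix[1:]};base64,{base64_data}"


# assumes dog.png exists alongside this script, as in the test fixture
print(to_data_url(pathlib.Path(__file__).parent / "dog.png"))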