From 97798c84420ebee05533168738b63108d42366bf Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 26 Dec 2024 09:13:34 -0800 Subject: [PATCH] add a RAG test to client SDK --- llama_stack/apis/agents/agents.py | 1 + .../agents/meta_reference/agent_instance.py | 47 ++++++++++--- tests/client-sdk/agents/test_agents.py | 66 +++++++++++++++++++ 3 files changed, 105 insertions(+), 9 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 3348211c9..14278b803 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -184,6 +184,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel): AgentTurnResponseEventType.step_complete.value ) step_type: StepType + step_id: str step_details: Step diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index ba190f567..1ecb95e68 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -313,6 +313,7 @@ class ChatAgent(ShieldRunnerMixin): event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, + step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, @@ -333,6 +334,7 @@ class ChatAgent(ShieldRunnerMixin): event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, + step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, @@ -355,28 +357,26 @@ class ChatAgent(ShieldRunnerMixin): if self.agent_config.preprocessing_tools: with tracing.span("preprocessing_tools") as span: for tool_name in self.agent_config.preprocessing_tools: + step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( 
step_type=StepType.tool_execution.value, - step_id=str(uuid.uuid4()), + step_id=step_id, ) ) ) args = dict( session_id=session_id, + turn_id=turn_id, input_messages=input_messages, attachments=attachments, ) - result = await self.tool_runtime_api.invoke_tool( - tool_name=tool_name, - args=args, - ) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( step_type=StepType.tool_execution.value, - step_id=str(uuid.uuid4()), + step_id=step_id, tool_call_delta=ToolCallDelta( parse_status=ToolCallParseStatus.success, content=ToolCall( @@ -386,6 +386,37 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) + result = await self.tool_runtime_api.invoke_tool( + tool_name=tool_name, + args=args, + ) + + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + step_details=ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[ + ToolCall( + call_id="", + tool_name=tool_name, + arguments={}, + ) + ], + tool_responses=[ + ToolResponse( + call_id="", + tool_name=tool_name, + content=result.content, + ) + ], + ), + ) + ) + ) span.set_attribute( "input", [m.model_dump_json() for m in input_messages] ) @@ -393,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("error_code", result.error_code) span.set_attribute("error_message", result.error_message) span.set_attribute("tool_name", tool_name) - if result.error_code != 0 and result.content: + if result.error_code == 0: last_message = input_messages[-1] last_message.context = result.content @@ -405,8 +436,6 @@ class ChatAgent(ShieldRunnerMixin): for tool in self.agent_config.custom_tools: custom_tools[tool.name] = tool while True: - msg = input_messages[-1] - step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( diff --git a/tests/client-sdk/agents/test_agents.py 
b/tests/client-sdk/agents/test_agents.py
index 1b2192949..10aaa09b5 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -15,6 +15,7 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types import ToolResponseMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.custom_tool_def import Parameter
+from llama_stack_client.types.memory_insert_params import Document
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 
 
@@ -230,3 +231,70 @@ def test_custom_tool(llama_stack_client, agent_config):
     logs_str = "".join(logs)
     assert "-100" in logs_str
     assert "CustomTool" in logs_str
+
+
+def test_rag_agent(llama_stack_client, agent_config):
+    """Exercise a RAG agent over several turns backed by a vector memory bank.
+
+    Registers the ``test_bank`` vector memory bank, inserts one document per
+    torchtune tutorial file (each document's content is the tutorial's raw
+    GitHub URL), and creates an agent whose config adds ``memory-tool`` as a
+    preprocessing tool.  Each prompt is sent as its own turn in one session,
+    and the rendered event log of that turn must contain ``Tool:memory-tool``.
+    """
+    bank_id = "test_bank"
+    tutorial_files = (
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    )
+
+    # One document per tutorial; ids are num-0, num-1, ... in list order.
+    documents = []
+    for index, filename in enumerate(tutorial_files):
+        documents.append(
+            Document(
+                document_id=f"num-{index}",
+                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{filename}",
+                mime_type="text/plain",
+                metadata={},
+            )
+        )
+
+    # Register the bank first, then insert the documents into it.
+    llama_stack_client.memory_banks.register(
+        memory_bank_id=bank_id,
+        params={
+            "memory_bank_type": "vector",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "chunk_size_in_tokens": 512,
+            "overlap_size_in_tokens": 64,
+        },
+        provider_id="faiss",
+    )
+    llama_stack_client.memory.insert(bank_id=bank_id, documents=documents)
+
+    # Copy the incoming config and add the memory tool as a preprocessing step.
+    rag_agent_config = {**agent_config, "preprocessing_tools": ["memory-tool"]}
+    agent = Agent(llama_stack_client, rag_agent_config)
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    user_prompts = (
+        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
+        "Was anything related to 'Llama3' discussed, if so what?",
+        "Tell me how to use LoRA",
+        "What about Quantization?",
+    )
+    for prompt in user_prompts:
+        response = agent.create_turn(
+            messages=[{"role": "user", "content": prompt}],
+            session_id=session_id,
+        )
+        # Render every non-None logged event of this turn into one string.
+        rendered = "".join(
+            str(event) for event in EventLogger().log(response) if event is not None
+        )
+        assert "Tool:memory-tool" in rendered