add a RAG test to client SDK

This commit is contained in:
Dinesh Yeduguru 2024-12-26 09:13:34 -08:00
parent c76f5f418f
commit 97798c8442
3 changed files with 105 additions and 9 deletions

View file

@@ -184,6 +184,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
AgentTurnResponseEventType.step_complete.value
)
step_type: StepType
step_id: str
step_details: Step

View file

@@ -313,6 +313,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@@ -333,6 +334,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@@ -355,28 +357,26 @@ class ChatAgent(ShieldRunnerMixin):
if self.agent_config.preprocessing_tools:
with tracing.span("preprocessing_tools") as span:
for tool_name in self.agent_config.preprocessing_tools:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=str(uuid.uuid4()),
step_id=step_id,
)
)
)
args = dict(
session_id=session_id,
turn_id=turn_id,
input_messages=input_messages,
attachments=attachments,
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name,
args=args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=str(uuid.uuid4()),
step_id=step_id,
tool_call_delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
@@ -386,6 +386,37 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name,
args=args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=tool_name,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=tool_name,
content=result.content,
)
],
),
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
@@ -393,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", tool_name)
if result.error_code != 0 and result.content:
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content
@@ -405,8 +436,6 @@ class ChatAgent(ShieldRunnerMixin):
for tool in self.agent_config.custom_tools:
custom_tools[tool.name] = tool
while True:
msg = input_messages[-1]
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(

View file

@@ -15,6 +15,7 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.custom_tool_def import Parameter
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
@@ -230,3 +231,68 @@ def test_custom_tool(llama_stack_client, agent_config):
logs_str = "".join(logs)
assert "-100" in logs_str
assert "CustomTool" in logs_str
def test_rag_agent(llama_stack_client, agent_config):
    """End-to-end RAG smoke test.

    Registers a faiss-backed vector memory bank, ingests the torchtune
    tutorial docs by URL, then runs an agent configured with the memory
    preprocessing tool and checks that the tool fires on every turn.
    """
    source_files = (
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    )
    # Reference the raw docs by URL; the server fetches and chunks them.
    docs = []
    for idx, fname in enumerate(source_files):
        docs.append(
            Document(
                document_id=f"num-{idx}",
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{fname}",
                mime_type="text/plain",
                metadata={},
            )
        )

    llama_stack_client.memory_banks.register(
        memory_bank_id="test_bank",
        params={
            "memory_bank_type": "vector",
            "embedding_model": "all-MiniLM-L6-v2",
            "chunk_size_in_tokens": 512,
            "overlap_size_in_tokens": 64,
        },
        provider_id="faiss",
    )
    # Populate the bank so the memory tool has something to retrieve.
    llama_stack_client.memory.insert(
        bank_id="test_bank",
        documents=docs,
    )

    # Enable the memory tool as a preprocessing step on every turn.
    rag_agent_config = dict(agent_config, preprocessing_tools=["memory-tool"])
    agent = Agent(llama_stack_client, rag_agent_config)
    session_id = agent.create_session(f"test-session-{uuid4()}")

    queries = (
        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
        "Was anything related to 'Llama3' discussed, if so what?",
        "Tell me how to use LoRA",
        "What about Quantization?",
    )
    for query in queries:
        turn = agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": query,
                }
            ],
            session_id=session_id,
        )
        # Each turn's event log must show the memory tool being invoked.
        events = [str(entry) for entry in EventLogger().log(turn) if entry is not None]
        assert "Tool:memory-tool" in "".join(events)