add a RAG test to client SDK

This commit is contained in:
Dinesh Yeduguru 2024-12-26 09:13:34 -08:00
parent c76f5f418f
commit 97798c8442
3 changed files with 105 additions and 9 deletions

View file

@ -313,6 +313,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@ -333,6 +334,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@ -355,28 +357,26 @@ class ChatAgent(ShieldRunnerMixin):
if self.agent_config.preprocessing_tools:
with tracing.span("preprocessing_tools") as span:
for tool_name in self.agent_config.preprocessing_tools:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=str(uuid.uuid4()),
step_id=step_id,
)
)
)
args = dict(
session_id=session_id,
turn_id=turn_id,
input_messages=input_messages,
attachments=attachments,
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name,
args=args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=str(uuid.uuid4()),
step_id=step_id,
tool_call_delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
@ -386,6 +386,37 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name,
args=args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=tool_name,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=tool_name,
content=result.content,
)
],
),
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
@ -393,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", tool_name)
if result.error_code != 0 and result.content:
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content
@ -405,8 +436,6 @@ class ChatAgent(ShieldRunnerMixin):
for tool in self.agent_config.custom_tools:
custom_tools[tool.name] = tool
while True:
msg = input_messages[-1]
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(