Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)
add a RAG test to client SDK

Commit 97798c8442 (parent c76f5f418f)
3 changed files with 105 additions and 9 deletions
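In brief: the step-complete payload gains a top-level `step_id` field; `ChatAgent` now mints a single `step_id` per preprocessing-tool run and reuses it across that step's start, progress, and complete events (previously each payload generated a fresh `uuid4`, so the events could not be correlated); `turn_id` is threaded into the tool arguments; the tool is invoked only after the progress chunk is emitted, and its result is surfaced in a new completion event carrying a `ToolExecutionStep`; the guard for attaching tool output to the last message's context is corrected to fire on success (`error_code == 0`) rather than on failure; and a `test_rag_agent` end-to-end test is added to the client SDK suite.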
File 1 of 3 — agent API types (path not shown on this page):

@@ -184,6 +184,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
         AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
+    step_id: str
     step_details: Step
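With `step_id` promoted to a top-level field on the step-complete payload (mirroring the existing step-start payload), a streaming client can pair the two events without inspecting `step_details`. Below is a minimal sketch of such pairing; it assumes `chunks` is an iterable of the `AgentTurnResponseStreamChunk` objects yielded in the ChatAgent hunks further down, and the helper itself is illustrative, not part of this commit:

    import time

    def step_latencies(chunks):
        # Pair step_start / step_complete events via the new step_id field.
        # Assumes event_type values are the plain strings "step_start" /
        # "step_complete" (the enum values referenced in the diff above).
        started = {}    # step_id -> monotonic start time
        latencies = {}  # step_id -> (step_type, elapsed seconds)
        for chunk in chunks:
            payload = chunk.event.payload
            if payload.event_type == "step_start":
                started[payload.step_id] = time.monotonic()
            elif payload.event_type == "step_complete":
                t0 = started.pop(payload.step_id, None)
                if t0 is not None:
                    latencies[payload.step_id] = (
                        payload.step_type,
                        time.monotonic() - t0,
                    )
        return latencies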
File 2 of 3 — ChatAgent implementation (path not shown on this page):

@@ -313,6 +313,7 @@ class ChatAgent(ShieldRunnerMixin):
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
+                        step_id=step_id,
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
@@ -333,6 +334,7 @@ class ChatAgent(ShieldRunnerMixin):
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
+                        step_id=step_id,
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
@@ -355,28 +357,26 @@ class ChatAgent(ShieldRunnerMixin):
         if self.agent_config.preprocessing_tools:
             with tracing.span("preprocessing_tools") as span:
                 for tool_name in self.agent_config.preprocessing_tools:
+                    step_id = str(uuid.uuid4())
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepStartPayload(
                                 step_type=StepType.tool_execution.value,
-                                step_id=str(uuid.uuid4()),
+                                step_id=step_id,
                             )
                         )
                     )
                     args = dict(
                         session_id=session_id,
+                        turn_id=turn_id,
                         input_messages=input_messages,
                         attachments=attachments,
                     )
-                    result = await self.tool_runtime_api.invoke_tool(
-                        tool_name=tool_name,
-                        args=args,
-                    )
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepProgressPayload(
                                 step_type=StepType.tool_execution.value,
-                                step_id=str(uuid.uuid4()),
+                                step_id=step_id,
                                 tool_call_delta=ToolCallDelta(
                                     parse_status=ToolCallParseStatus.success,
                                     content=ToolCall(
@@ -386,6 +386,37 @@ class ChatAgent(ShieldRunnerMixin):
                             )
                         )
                     )
+                    result = await self.tool_runtime_api.invoke_tool(
+                        tool_name=tool_name,
+                        args=args,
+                    )
+
+                    yield AgentTurnResponseStreamChunk(
+                        event=AgentTurnResponseEvent(
+                            payload=AgentTurnResponseStepCompletePayload(
+                                step_type=StepType.tool_execution.value,
+                                step_id=step_id,
+                                step_details=ToolExecutionStep(
+                                    step_id=step_id,
+                                    turn_id=turn_id,
+                                    tool_calls=[
+                                        ToolCall(
+                                            call_id="",
+                                            tool_name=tool_name,
+                                            arguments={},
+                                        )
+                                    ],
+                                    tool_responses=[
+                                        ToolResponse(
+                                            call_id="",
+                                            tool_name=tool_name,
+                                            content=result.content,
+                                        )
+                                    ],
+                                ),
+                            )
+                        )
+                    )
                     span.set_attribute(
                         "input", [m.model_dump_json() for m in input_messages]
                     )
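Taken together, the two hunks above give every preprocessing-tool run a coherent start → progress → complete sequence under one `step_id`, invoke the tool only after the progress chunk has been emitted, and record both the call and its response in a `ToolExecutionStep`. Here is a sketch of how a consumer might harvest tool outputs from those completion events; the helper and the literal string comparisons are assumptions drawn from the payloads shown above, not an API added by this commit:

    def tool_outputs(chunks):
        # Collect tool responses from tool_execution step_complete events.
        outputs = {}
        for chunk in chunks:
            payload = chunk.event.payload
            if (
                payload.event_type == "step_complete"
                and payload.step_type == "tool_execution"
            ):
                # step_details is the ToolExecutionStep built in the hunk above
                for resp in payload.step_details.tool_responses:
                    outputs.setdefault(resp.tool_name, []).append(resp.content)
        return outputs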
@@ -393,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
                     span.set_attribute("error_code", result.error_code)
                     span.set_attribute("error_message", result.error_message)
                     span.set_attribute("tool_name", tool_name)
-                    if result.error_code != 0 and result.content:
+                    if result.error_code == 0:
                         last_message = input_messages[-1]
                         last_message.context = result.content
@@ -405,8 +436,6 @@ class ChatAgent(ShieldRunnerMixin):
             for tool in self.agent_config.custom_tools:
                 custom_tools[tool.name] = tool
         while True:
-            msg = input_messages[-1]
-
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
File 3 of 3 — client SDK agent tests (path not shown on this page):

@@ -15,6 +15,7 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types import ToolResponseMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.custom_tool_def import Parameter
+from llama_stack_client.types.memory_insert_params import Document
 from llama_stack_client.types.shared.completion_message import CompletionMessage
@@ -230,3 +231,68 @@ def test_custom_tool(llama_stack_client, agent_config):
     logs_str = "".join(logs)
     assert "-100" in logs_str
     assert "CustomTool" in logs_str
+
+
+def test_rag_agent(llama_stack_client, agent_config):
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    ]
+    documents = [
+        Document(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={},
+        )
+        for i, url in enumerate(urls)
+    ]
+    llama_stack_client.memory_banks.register(
+        memory_bank_id="test_bank",
+        params={
+            "memory_bank_type": "vector",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "chunk_size_in_tokens": 512,
+            "overlap_size_in_tokens": 64,
+        },
+        provider_id="faiss",
+    )
+
+    # insert some documents
+    llama_stack_client.memory.insert(
+        bank_id="test_bank",
+        documents=documents,
+    )
+
+    agent_config = {
+        **agent_config,
+        "preprocessing_tools": ["memory-tool"],
+    }
+    agent = Agent(llama_stack_client, agent_config)
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    user_prompts = [
+        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
+        "Was anything related to 'Llama3' discussed, if so what?",
+        "Tell me how to use LoRA",
+        "What about Quantization?",
+    ]
+
+    for prompt in user_prompts:
+        response = agent.create_turn(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt,
+                }
+            ],
+            session_id=session_id,
+        )
+
+        logs = [str(log) for log in EventLogger().log(response) if log is not None]
+        logs_str = "".join(logs)
+        assert "Tool:memory-tool" in logs_str
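The new `test_rag_agent` exercises the full path: register a FAISS-backed vector memory bank, insert the torchtune tutorial pages as documents, enable the `memory-tool` preprocessing tool on the agent, and assert from the event logs that the tool ran on every turn. One caveat: the test hardcodes `memory_bank_id="test_bank"`, so repeated runs against a persistent stack may collide with a previously registered bank. A defensive variant using only the calls shown in the diff (whether re-registering an existing id errors depends on the memory provider, which this page does not show):

    from uuid import uuid4

    # Randomize the bank id so reruns against a persistent stack do not
    # collide with an earlier "test_bank" registration.
    bank_id = f"test-bank-{uuid4()}"
    llama_stack_client.memory_banks.register(
        memory_bank_id=bank_id,
        params={
            "memory_bank_type": "vector",
            "embedding_model": "all-MiniLM-L6-v2",
            "chunk_size_in_tokens": 512,
            "overlap_size_in_tokens": 64,
        },
        provider_id="faiss",
    )
    llama_stack_client.memory.insert(bank_id=bank_id, documents=documents)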