mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
add a RAG test to client SDK
This commit is contained in:
parent
c76f5f418f
commit
97798c8442
3 changed files with 105 additions and 9 deletions
|
@ -184,6 +184,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
|
|||
AgentTurnResponseEventType.step_complete.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
step_details: Step
|
||||
|
||||
|
||||
|
|
|
@ -313,6 +313,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
|
@ -333,6 +334,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
|
@ -355,28 +357,26 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if self.agent_config.preprocessing_tools:
|
||||
with tracing.span("preprocessing_tools") as span:
|
||||
for tool_name in self.agent_config.preprocessing_tools:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=str(uuid.uuid4()),
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
args = dict(
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=input_messages,
|
||||
attachments=attachments,
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=str(uuid.uuid4()),
|
||||
step_id=step_id,
|
||||
tool_call_delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
content=ToolCall(
|
||||
|
@ -386,6 +386,37 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="",
|
||||
tool_name=tool_name,
|
||||
arguments={},
|
||||
)
|
||||
],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id="",
|
||||
tool_name=tool_name,
|
||||
content=result.content,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
|
@ -393,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute("error_code", result.error_code)
|
||||
span.set_attribute("error_message", result.error_message)
|
||||
span.set_attribute("tool_name", tool_name)
|
||||
if result.error_code != 0 and result.content:
|
||||
if result.error_code == 0:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = result.content
|
||||
|
||||
|
@ -405,8 +436,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for tool in self.agent_config.custom_tools:
|
||||
custom_tools[tool.name] = tool
|
||||
while True:
|
||||
msg = input_messages[-1]
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
|
@ -15,6 +15,7 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
|
|||
from llama_stack_client.types import ToolResponseMessage
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types.custom_tool_def import Parameter
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||
|
||||
|
||||
|
@ -230,3 +231,68 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
logs_str = "".join(logs)
|
||||
assert "-100" in logs_str
|
||||
assert "CustomTool" in logs_str
|
||||
|
||||
|
||||
def test_rag_agent(llama_stack_client, agent_config):
    """Check that an agent with a memory preprocessing tool reports using it.

    The test registers a vector memory bank, inserts one document per
    torchtune tutorial file (each document's content is the tutorial's raw
    GitHub URL), creates an agent whose config lists ``memory-tool`` under
    ``preprocessing_tools``, and then sends several prompts in one session.
    For every turn, the logged event output must mention the memory tool.

    Args:
        llama_stack_client: Client fixture used to register the bank, insert
            documents and back the agent.
        agent_config: Base agent configuration fixture; it is copied and
            extended here rather than modified in place.
    """
    bank_id = "test_bank"

    tutorial_files = (
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    )
    documents = []
    for i, url in enumerate(tutorial_files):
        documents.append(
            Document(
                document_id=f"num-{i}",
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                mime_type="text/plain",
                metadata={},
            )
        )

    bank_params = {
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
    }
    llama_stack_client.memory_banks.register(
        memory_bank_id=bank_id,
        params=bank_params,
        provider_id="faiss",
    )

    # Populate the freshly registered bank before the agent is created.
    llama_stack_client.memory.insert(bank_id=bank_id, documents=documents)

    # Build a new config so the caller's fixture object is left untouched.
    rag_config = dict(agent_config, preprocessing_tools=["memory-tool"])
    agent = Agent(llama_stack_client, rag_config)
    session_id = agent.create_session(f"test-session-{uuid4()}")

    user_prompts = (
        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
        "Was anything related to 'Llama3' discussed, if so what?",
        "Tell me how to use LoRA",
        "What about Quantization?",
    )

    for prompt in user_prompts:
        user_message = {"role": "user", "content": prompt}
        response = agent.create_turn(
            messages=[user_message],
            session_id=session_id,
        )

        # Drain the turn's events through the logger and flatten to one string.
        logs_str = "".join(
            str(event)
            for event in EventLogger().log(response)
            if event is not None
        )
        assert "Tool:memory-tool" in logs_str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue