forked from phoenix-oss/llama-stack-mirror
Telemetry API redesign (#525)
# What does this PR do? Change the Telemetry API to be able to support different use cases like returning traces for the UI and ability to export for Evals. Other changes: * Add a new trace_protocol decorator to decorate all our API methods so that any call to them will automatically get traced across all impls. * There is some issue with the decorator pattern of span creation when using async generators, where there are multiple yields with in the same context. I think its much more explicit by using the explicit context manager pattern using with. I moved the span creations in agent instance to be using with * Inject session id at the turn level, which should quickly give us all traces across turns for a given session Addresses #509 ## Test Plan ``` llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000 curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \ -H 'Content-Type: application/json' \ -d '{ "attribute_filters": [ { "key": "session_id", "op": "eq", "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65" }], "limit": 100, "offset": 0, "order_by": ["start_time"] }' | jq . [ { "trace_id": "6902f54b83b4b48be18a6f422b13e16f", "root_span_id": "5f37b85543afc15a", "start_time": "2024-12-04T08:08:30.501587", "end_time": "2024-12-04T08:08:36.026463" }, { "trace_id": "92227dac84c0615ed741be393813fb5f", "root_span_id": "af7c5bb46665c2c8", "start_time": "2024-12-04T08:08:36.031170", "end_time": "2024-12-04T08:08:41.693301" }, { "trace_id": "7d578a6edac62f204ab479fba82f77b6", "root_span_id": "1d935e3362676896", "start_time": "2024-12-04T08:08:41.695204", "end_time": "2024-12-04T08:08:47.228016" }, { "trace_id": "dbd767d76991bc816f9f078907dc9ff2", "root_span_id": "f5a7ee76683b9602", "start_time": "2024-12-04T08:08:47.234578", "end_time": "2024-12-04T08:08:53.189412" } ] curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \ -H 'Content-Type: application/json' \ -d '{ "span_id" : "6cceb4b48a156913", "max_depth": 2, "attributes_to_return": ["input"] }' | jq . % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 875 100 790 100 85 18462 1986 --:--:-- --:--:-- --:--:-- 20833 { "span_id": "6cceb4b48a156913", "trace_id": "dafa796f6aaf925f511c04cd7c67fdda", "parent_span_id": "892a66d726c7f990", "name": "retrieve_rag_context", "start_time": "2024-12-04T09:28:21.781995", "end_time": "2024-12-04T09:28:21.913352", "attributes": { "input": [ "{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}", "{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}" ] }, "children": [ { "span_id": "1a2df181854064a8", "trace_id": "dafa796f6aaf925f511c04cd7c67fdda", "parent_span_id": "6cceb4b48a156913", "name": "MemoryRouter.query_documents", "start_time": "2024-12-04T09:28:21.787620", "end_time": "2024-12-04T09:28:21.906512", "attributes": { "input": null }, "children": [], "status": "ok" } ], "status": "ok" } ``` <img width="1677" alt="Screenshot 2024-12-04 at 9 42 56 AM" src="https://github.com/user-attachments/assets/4d3cea93-05ce-415a-93d9-4b1628631bf8">
This commit is contained in:
parent
16769256b7
commit
fcd6449519
34 changed files with 1551 additions and 245 deletions
|
@ -144,87 +144,91 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
@tracing.span("create_and_execute_turn")
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages.extend(request.messages)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
assert output_message is not None
|
||||
yield chunk
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
assert output_message is not None
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
@ -273,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
yield final_response
|
||||
|
||||
@tracing.span("run_shields")
|
||||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
|
@ -281,23 +284,47 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
if len(shields) == 0:
|
||||
return
|
||||
with tracing.span("run_shields") as span:
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", e.violation.model_dump_json())
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
|
@ -305,30 +332,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", "no violations")
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
|
@ -356,10 +365,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
with tracing.span("retrieve_rag_context"):
|
||||
with tracing.span("retrieve_rag_context") as span:
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session_id, input_messages, attachments
|
||||
)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute("output", rag_context)
|
||||
span.set_attribute("bank_ids", bank_ids)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -416,7 +430,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference"):
|
||||
with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
@ -436,7 +450,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if isinstance(delta, ToolCallDelta):
|
||||
if delta.parse_status == ToolCallParseStatus.success:
|
||||
tool_calls.append(delta.content)
|
||||
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -466,6 +479,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if event.stop_reason is not None:
|
||||
stop_reason = event.stop_reason
|
||||
span.set_attribute("stop_reason", stop_reason)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute(
|
||||
"output", f"content: {content} tool_calls: {tool_calls}"
|
||||
)
|
||||
|
||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||
|
||||
|
@ -549,7 +569,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
with tracing.span("tool_execution"):
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
[message],
|
||||
|
@ -558,6 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
len(result_messages) == 1
|
||||
), "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue