Telemetry API redesign (#525)

# What does this PR do?
Change the Telemetry API so it can support different use cases, such as
returning traces for the UI and exporting traces for Evals.
Other changes:
* Add a new `trace_protocol` decorator to decorate all our API methods so
that any call to them automatically gets traced across all impls.
* The decorator pattern for span creation runs into issues with async
generators, where there are multiple yields within the same context. The
explicit context-manager pattern with `with` makes the span lifetime much
clearer, so the span creation in the agent instance now uses `with` (a
minimal sketch of the pattern follows this list).
* Inject the session id at the turn level, which should quickly give us all
traces across turns for a given session.
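
To make the async-generator point concrete, here is a minimal sketch of the explicit context-manager pattern. The `span()` helper below is a toy stand-in for the tracing utility used in the agent code (an illustration, not the real implementation); the point is that an explicit `with` keeps every `yield` of the generator inside one clearly delimited span and lets attributes be attached along the way.

```python
import asyncio
from contextlib import contextmanager


# Toy stand-in for the project's tracing.span() helper, used only to show
# the pattern; the real helper also forwards attributes to telemetry sinks.
class _Span:
    def __init__(self, name: str) -> None:
        self.name = name
        self.attributes: dict = {}

    def set_attribute(self, key: str, value) -> None:
        self.attributes[key] = value


@contextmanager
def span(name: str):
    s = _Span(name)
    print(f"span start: {name}")
    try:
        yield s
    finally:
        # Runs once the async generator is exhausted or closed, so the span
        # covers every yield that happened inside the `with` block.
        print(f"span end: {name} attributes={s.attributes}")


async def create_and_execute_turn(session_id: str):
    # Explicit `with` instead of a decorator: all yields below belong to one
    # clearly delimited span, and attributes can be attached as we go.
    with span("create_and_execute_turn") as s:
        s.set_attribute("session_id", session_id)
        for i in range(3):
            yield f"chunk-{i}"


async def main():
    async for chunk in create_and_execute_turn("dd667b87"):
        print(chunk)


asyncio.run(main())
```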

Addresses #509

## Test Plan
```
llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml
PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000


 curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \
-H 'Content-Type: application/json' \
-d '{
  "attribute_filters": [
    {
      "key": "session_id",
      "op": "eq",
      "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65" }],
  "limit": 100,
  "offset": 0,
  "order_by": ["start_time"]
}' | jq .
[
  {
    "trace_id": "6902f54b83b4b48be18a6f422b13e16f",
    "root_span_id": "5f37b85543afc15a",
    "start_time": "2024-12-04T08:08:30.501587",
    "end_time": "2024-12-04T08:08:36.026463"
  },
  {
    "trace_id": "92227dac84c0615ed741be393813fb5f",
    "root_span_id": "af7c5bb46665c2c8",
    "start_time": "2024-12-04T08:08:36.031170",
    "end_time": "2024-12-04T08:08:41.693301"
  },
  {
    "trace_id": "7d578a6edac62f204ab479fba82f77b6",
    "root_span_id": "1d935e3362676896",
    "start_time": "2024-12-04T08:08:41.695204",
    "end_time": "2024-12-04T08:08:47.228016"
  },
  {
    "trace_id": "dbd767d76991bc816f9f078907dc9ff2",
    "root_span_id": "f5a7ee76683b9602",
    "start_time": "2024-12-04T08:08:47.234578",
    "end_time": "2024-12-04T08:08:53.189412"
  }
]


curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \
-H 'Content-Type: application/json' \
-d '{ "span_id" : "6cceb4b48a156913", "max_depth": 2, "attributes_to_return": ["input"] }' | jq .
{
  "span_id": "6cceb4b48a156913",
  "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
  "parent_span_id": "892a66d726c7f990",
  "name": "retrieve_rag_context",
  "start_time": "2024-12-04T09:28:21.781995",
  "end_time": "2024-12-04T09:28:21.913352",
  "attributes": {
    "input": [
      "{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}",
      "{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}"
    ]
  },
  "children": [
    {
      "span_id": "1a2df181854064a8",
      "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
      "parent_span_id": "6cceb4b48a156913",
      "name": "MemoryRouter.query_documents",
      "start_time": "2024-12-04T09:28:21.787620",
      "end_time": "2024-12-04T09:28:21.906512",
      "attributes": {
        "input": null
      },
      "children": [],
      "status": "ok"
    }
  ],
  "status": "ok"
}

```
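
For reference, the first query above can also be issued from Python. This is just the curl request rewritten with `requests`; the endpoint, filters, and session id are taken verbatim from the example rather than from any client SDK.

```python
import requests

# Same request as the query-traces curl above: fetch traces whose session_id
# attribute matches the given session, ordered by start_time.
resp = requests.post(
    "http://localhost:5000/alpha/telemetry/query-traces",
    json={
        "attribute_filters": [
            {
                "key": "session_id",
                "op": "eq",
                "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65",
            }
        ],
        "limit": 100,
        "offset": 0,
        "order_by": ["start_time"],
    },
    timeout=30,
)
resp.raise_for_status()
for trace in resp.json():
    print(trace["trace_id"], trace["start_time"], trace["end_time"])
```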

<img width="1677" alt="Screenshot 2024-12-04 at 9 42 56 AM"
src="https://github.com/user-attachments/assets/4d3cea93-05ce-415a-93d9-4b1628631bf8">

34 changed files with 1551 additions and 245 deletions; the `ChatAgent` changes are shown below.

```diff
@@ -144,87 +144,91 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
-    @tracing.span("create_and_execute_turn")
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
-        assert request.stream is True, "Non-streaming not supported"
-
-        session_info = await self.storage.get_session_info(request.session_id)
-        if session_info is None:
-            raise ValueError(f"Session {request.session_id} not found")
-
-        turns = await self.storage.get_session_turns(request.session_id)
-
-        messages = []
-        if self.agent_config.instructions != "":
-            messages.append(SystemMessage(content=self.agent_config.instructions))
-
-        for i, turn in enumerate(turns):
-            messages.extend(self.turn_to_messages(turn))
-
-        messages.extend(request.messages)
-
-        turn_id = str(uuid.uuid4())
-        start_time = datetime.now()
-        yield AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseTurnStartPayload(
-                    turn_id=turn_id,
+        with tracing.span("create_and_execute_turn") as span:
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+
+            messages = []
+            if self.agent_config.instructions != "":
+                messages.append(SystemMessage(content=self.agent_config.instructions))
+
+            for i, turn in enumerate(turns):
+                messages.extend(self.turn_to_messages(turn))
+
+            messages.extend(request.messages)
+
+            turn_id = str(uuid.uuid4())
+            span.set_attribute("turn_id", turn_id)
+            start_time = datetime.now()
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnStartPayload(
+                        turn_id=turn_id,
+                    )
                 )
             )
-        )
 
-        steps = []
-        output_message = None
-        async for chunk in self.run(
-            session_id=request.session_id,
-            turn_id=turn_id,
-            input_messages=messages,
-            attachments=request.attachments or [],
-            sampling_params=self.agent_config.sampling_params,
-            stream=request.stream,
-        ):
-            if isinstance(chunk, CompletionMessage):
-                log.info(
-                    f"{chunk.role.capitalize()}: {chunk.content}",
-                )
-                output_message = chunk
-                continue
-
-            assert isinstance(
-                chunk, AgentTurnResponseStreamChunk
-            ), f"Unexpected type {type(chunk)}"
-            event = chunk.event
-            if (
-                event.payload.event_type
-                == AgentTurnResponseEventType.step_complete.value
-            ):
-                steps.append(event.payload.step_details)
-
-            yield chunk
-
-        assert output_message is not None
-
-        turn = Turn(
-            turn_id=turn_id,
-            session_id=request.session_id,
-            input_messages=request.messages,
-            output_message=output_message,
-            started_at=start_time,
-            completed_at=datetime.now(),
-            steps=steps,
-        )
-        await self.storage.add_turn_to_session(request.session_id, turn)
-
-        chunk = AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseTurnCompletePayload(
-                    turn=turn,
+            steps = []
+            output_message = None
+            async for chunk in self.run(
+                session_id=request.session_id,
+                turn_id=turn_id,
+                input_messages=messages,
+                attachments=request.attachments or [],
+                sampling_params=self.agent_config.sampling_params,
+                stream=request.stream,
+            ):
+                if isinstance(chunk, CompletionMessage):
+                    log.info(
+                        f"{chunk.role.capitalize()}: {chunk.content}",
+                    )
+                    output_message = chunk
+                    continue
+
+                assert isinstance(
+                    chunk, AgentTurnResponseStreamChunk
+                ), f"Unexpected type {type(chunk)}"
+                event = chunk.event
+                if (
+                    event.payload.event_type
+                    == AgentTurnResponseEventType.step_complete.value
+                ):
+                    steps.append(event.payload.step_details)
+
+                yield chunk
+
+            assert output_message is not None
+
+            turn = Turn(
+                turn_id=turn_id,
+                session_id=request.session_id,
+                input_messages=request.messages,
+                output_message=output_message,
+                started_at=start_time,
+                completed_at=datetime.now(),
+                steps=steps,
+            )
+            await self.storage.add_turn_to_session(request.session_id, turn)
+
+            chunk = AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnCompletePayload(
+                        turn=turn,
+                    )
                 )
             )
-        )
-        yield chunk
+            yield chunk
 
     async def run(
         self,
@@ -273,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin):
 
         yield final_response
 
-    @tracing.span("run_shields")
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
@@ -281,23 +284,47 @@ class ChatAgent(ShieldRunnerMixin):
         shields: List[str],
         touchpoint: str,
     ) -> AsyncGenerator:
-        if len(shields) == 0:
-            return
-
-        step_id = str(uuid.uuid4())
-        try:
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.shield_call.value,
-                        step_id=step_id,
-                        metadata=dict(touchpoint=touchpoint),
-                    )
-                )
-            )
-
-            await self.run_multiple_shields(messages, shields)
-
-        except SafetyException as e:
+        with tracing.span("run_shields") as span:
+            span.set_attribute("turn_id", turn_id)
+            span.set_attribute("input", [m.model_dump_json() for m in messages])
+            if len(shields) == 0:
+                span.set_attribute("output", "no shields")
+                return
+
+            step_id = str(uuid.uuid4())
+            try:
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepStartPayload(
+                            step_type=StepType.shield_call.value,
+                            step_id=step_id,
+                            metadata=dict(touchpoint=touchpoint),
+                        )
+                    )
+                )
+
+                await self.run_multiple_shields(messages, shields)
+
+            except SafetyException as e:
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepCompletePayload(
+                            step_type=StepType.shield_call.value,
+                            step_details=ShieldCallStep(
+                                step_id=step_id,
+                                turn_id=turn_id,
+                                violation=e.violation,
+                            ),
+                        )
+                    )
+                )
+                span.set_attribute("output", e.violation.model_dump_json())
+
+                yield CompletionMessage(
+                    content=str(e),
+                    stop_reason=StopReason.end_of_turn,
+                )
+                yield False
+
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
@@ -305,30 +332,12 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                        violation=e.violation,
+                            violation=None,
                         ),
                     )
                 )
             )
-
-            yield CompletionMessage(
-                content=str(e),
-                stop_reason=StopReason.end_of_turn,
-            )
-            yield False
-
-        yield AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseStepCompletePayload(
-                    step_type=StepType.shield_call.value,
-                    step_details=ShieldCallStep(
-                        step_id=step_id,
-                        turn_id=turn_id,
-                        violation=None,
-                    ),
-                )
-            )
-        )
+            span.set_attribute("output", "no violations")
 
     async def _run(
         self,
@@ -356,10 +365,15 @@ class ChatAgent(ShieldRunnerMixin):
 
         # TODO: find older context from the session and either replace it
        # or append with a sliding window. this is really a very simplistic implementation
-        with tracing.span("retrieve_rag_context"):
+        with tracing.span("retrieve_rag_context") as span:
             rag_context, bank_ids = await self._retrieve_context(
                 session_id, input_messages, attachments
             )
+            span.set_attribute(
+                "input", [m.model_dump_json() for m in input_messages]
+            )
+            span.set_attribute("output", rag_context)
+            span.set_attribute("bank_ids", bank_ids)
 
         step_id = str(uuid.uuid4())
         yield AgentTurnResponseStreamChunk(
@@ -416,7 +430,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         content = ""
         stop_reason = None
-        with tracing.span("inference"):
+        with tracing.span("inference") as span:
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
@@ -436,7 +450,6 @@ class ChatAgent(ShieldRunnerMixin):
                 if isinstance(delta, ToolCallDelta):
                     if delta.parse_status == ToolCallParseStatus.success:
                         tool_calls.append(delta.content)
-
                     if stream:
                         yield AgentTurnResponseStreamChunk(
                             event=AgentTurnResponseEvent(
@@ -466,6 +479,13 @@ class ChatAgent(ShieldRunnerMixin):
 
                 if event.stop_reason is not None:
                     stop_reason = event.stop_reason
+            span.set_attribute("stop_reason", stop_reason)
+            span.set_attribute(
+                "input", [m.model_dump_json() for m in input_messages]
+            )
+            span.set_attribute(
+                "output", f"content: {content} tool_calls: {tool_calls}"
+            )
 
         stop_reason = stop_reason or StopReason.out_of_tokens
@@ -549,7 +569,13 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
 
-        with tracing.span("tool_execution"):
+        with tracing.span(
+            "tool_execution",
+            {
+                "tool_name": tool_call.tool_name,
+                "input": message.model_dump_json(),
+            },
+        ) as span:
             result_messages = await execute_tool_call_maybe(
                 self.tools_dict,
                 [message],
@@ -558,6 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
                 len(result_messages) == 1
             ), "Currently not supporting multiple messages"
             result_message = result_messages[0]
+            span.set_attribute("output", result_message.model_dump_json())
 
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
```