add endpoint to export traces and standerdize the span creation

This commit is contained in:
Dinesh Yeduguru 2024-11-25 16:01:52 -08:00
parent 54bc5f2d55
commit c6b4bf8ada
4 changed files with 88 additions and 5 deletions

View file

@ -6,7 +6,16 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@ -135,3 +144,8 @@ class Telemetry(Protocol):
@webmethod(route="/telemetry/get-trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
@webmethod(route="/telemetry/get-traces-for-session", method="POST")
async def get_traces_for_session(
self, session_id: str, lookback: str = "1h", limit: int = 100
) -> List[Trace]: ...

View file

@ -285,8 +285,9 @@ class ChatAgent(ShieldRunnerMixin):
) -> AsyncGenerator:
with tracing.span("run_shields") as span:
span.set_attribute("turn_id", turn_id)
span.set_attribute("messages", [m.model_dump_json() for m in messages])
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
@ -315,6 +316,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
@ -334,6 +336,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", "no violations")
async def _run(
self,
@ -365,7 +368,10 @@ class ChatAgent(ShieldRunnerMixin):
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
span.set_attribute("rag_context", rag_context)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", rag_context)
span.set_attribute("bank_ids", bank_ids)
step_id = str(uuid.uuid4())
@ -473,9 +479,11 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute("content", content)
span.set_attribute(
"tool_calls", [tc.model_dump_json() for tc in tool_calls]
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute(
"output", f"content: {content} tool_calls: {tool_calls}"
)
stop_reason = stop_reason or StopReason.out_of_tokens

View file

@ -18,6 +18,10 @@ class OpenTelemetryConfig(BaseModel):
default="llama-stack",
description="The service name to use for telemetry",
)
export_endpoint: str = Field(
default="http://localhost:16686/api/traces",
description="The Jaeger query endpoint URL",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:

View file

@ -5,6 +5,9 @@
# the root directory of this source tree.
import threading
from typing import Any, Dict, List
import aiohttp
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -206,3 +209,57 @@ class OpenTelemetryAdapter(Telemetry):
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError("Trace retrieval not implemented yet")
async def get_traces_for_session(
self, session_id: str, lookback: str = "1h", limit: int = 100
) -> List[Dict[str, Any]]:
params = {
"tags": f'{{"session_id":"{session_id}"}}',
"lookback": lookback,
"limit": limit,
"service": self.config.service_name,
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(
self.config.export_endpoint, params=params
) as response:
if response.status != 200:
raise Exception(
f"Failed to query Jaeger: {response.status} {await response.text()}"
)
traces_data = await response.json()
processed_traces = []
for trace_data in traces_data.get("data", []):
trace_steps = []
for span in trace_data.get("spans", []):
step_info = {
"step": span.get("operationName"),
"start_time": span.get("startTime"),
"duration": span.get("duration"),
}
tags = span.get("tags", [])
for tag in tags:
if tag.get("key") == "input":
step_info["input"] = tag.get("value")
elif tag.get("key") == "output":
step_info["output"] = tag.get("value")
# we only want to return steps which have input and output
if step_info.get("input") and step_info.get("output"):
trace_steps.append(step_info)
processed_traces.append(
{
"trace_id": trace_data.get("traceID"),
"steps": trace_steps,
}
)
return processed_traces
except Exception as e:
raise Exception(f"Error querying Jaeger traces: {str(e)}") from e