add endpoint to export traces and standardize span creation

Dinesh Yeduguru 2024-11-25 16:01:52 -08:00
parent 54bc5f2d55
commit c6b4bf8ada
4 changed files with 88 additions and 5 deletions

@@ -6,7 +6,16 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
@@ -135,3 +144,8 @@ class Telemetry(Protocol):
 
     @webmethod(route="/telemetry/get-trace", method="GET")
     async def get_trace(self, trace_id: str) -> Trace: ...
+
+    @webmethod(route="/telemetry/get-traces-for-session", method="POST")
+    async def get_traces_for_session(
+        self, session_id: str, lookback: str = "1h", limit: int = 100
+    ) -> List[Trace]: ...
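As a quick illustration of how a client could exercise the new route once a server is running, here is a hedged sketch; the host, port, and the choice of a JSON body are assumptions for illustration, not part of this commit:

# Hypothetical caller for the route declared above; base URL and payload shape are assumed.
import asyncio

import aiohttp


async def fetch_session_traces(session_id: str):
    async with aiohttp.ClientSession() as http:
        # POST matches the @webmethod declaration; arguments are assumed to travel as JSON.
        async with http.post(
            "http://localhost:5000/telemetry/get-traces-for-session",
            json={"session_id": session_id, "lookback": "1h", "limit": 100},
        ) as resp:
            resp.raise_for_status()
            return await resp.json()


if __name__ == "__main__":
    print(asyncio.run(fetch_session_traces("example-session")))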

@@ -285,8 +285,9 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> AsyncGenerator:
         with tracing.span("run_shields") as span:
             span.set_attribute("turn_id", turn_id)
-            span.set_attribute("messages", [m.model_dump_json() for m in messages])
+            span.set_attribute("input", [m.model_dump_json() for m in messages])
             if len(shields) == 0:
+                span.set_attribute("output", "no shields")
                 return
 
             step_id = str(uuid.uuid4())
@@ -315,6 +316,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
+                span.set_attribute("output", e.violation.model_dump_json())
 
                 yield CompletionMessage(
                     content=str(e),
@@ -334,6 +336,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
+            span.set_attribute("output", "no violations")
 
     async def _run(
         self,
@@ -365,7 +368,10 @@ class ChatAgent(ShieldRunnerMixin):
             rag_context, bank_ids = await self._retrieve_context(
                 session_id, input_messages, attachments
             )
-            span.set_attribute("rag_context", rag_context)
+            span.set_attribute(
+                "input", [m.model_dump_json() for m in input_messages]
+            )
+            span.set_attribute("output", rag_context)
             span.set_attribute("bank_ids", bank_ids)
 
         step_id = str(uuid.uuid4())
@@ -473,9 +479,11 @@ class ChatAgent(ShieldRunnerMixin):
                     if event.stop_reason is not None:
                         stop_reason = event.stop_reason
                 span.set_attribute("stop_reason", stop_reason)
-                span.set_attribute("content", content)
                 span.set_attribute(
-                    "tool_calls", [tc.model_dump_json() for tc in tool_calls]
+                    "input", [m.model_dump_json() for m in input_messages]
+                )
+                span.set_attribute(
+                    "output", f"content: {content} tool_calls: {tool_calls}"
                 )
 
             stop_reason = stop_reason or StopReason.out_of_tokens
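The ChatAgent changes above standardize span creation: each instrumented step now records a matched pair of "input" and "output" attributes instead of ad-hoc names like "messages", "content", or "tool_calls", which is exactly what the Jaeger export further down filters on. A minimal sketch of the convention, assuming the same tracing.span helper and import path used by ChatAgent:

# Minimal sketch of the standardized span convention; the import path is assumed.
from llama_stack.providers.utils.telemetry import tracing


def summarize(texts: list) -> str:
    with tracing.span("summarize") as span:
        # record what went into the step ...
        span.set_attribute("input", [t[:100] for t in texts])
        summary = " ".join(t.split(".")[0] for t in texts)
        # ... and what came out, always under the same two attribute names
        span.set_attribute("output", summary)
        return summary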

@@ -18,6 +18,10 @@ class OpenTelemetryConfig(BaseModel):
         default="llama-stack",
         description="The service name to use for telemetry",
     )
+    export_endpoint: str = Field(
+        default="http://localhost:16686/api/traces",
+        description="The Jaeger query endpoint URL",
+    )
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
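The default value targets the Jaeger trace-search API exposed by a standard all-in-one deployment (the query service listens on port 16686); it is distinct from the OTLP endpoint that spans are exported to. The adapter below filters traces with a tags parameter whose value is a small JSON object, which is what the doubled braces in the f-string produce. A sketch of the resulting request URL, with the session id made up for illustration:

# Sketch of the GET request the adapter sends to export_endpoint.
from urllib.parse import urlencode

params = {
    "tags": '{"session_id":"abc-123"}',  # what f'{{"session_id":"{session_id}"}}' renders to
    "lookback": "1h",
    "limit": 100,
    "service": "llama-stack",
}
print("http://localhost:16686/api/traces?" + urlencode(params))
# tags=%7B%22session_id%22%3A%22abc-123%22%7D&lookback=1h&limit=100&service=llama-stack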

@@ -5,6 +5,9 @@
 # the root directory of this source tree.
 
 import threading
+from typing import Any, Dict, List
+
+import aiohttp
 
 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -206,3 +209,57 @@ class OpenTelemetryAdapter(Telemetry):
 
     async def get_trace(self, trace_id: str) -> Trace:
         raise NotImplementedError("Trace retrieval not implemented yet")
+
+    async def get_traces_for_session(
+        self, session_id: str, lookback: str = "1h", limit: int = 100
+    ) -> List[Dict[str, Any]]:
+        params = {
+            "tags": f'{{"session_id":"{session_id}"}}',
+            "lookback": lookback,
+            "limit": limit,
+            "service": self.config.service_name,
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(
+                    self.config.export_endpoint, params=params
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Failed to query Jaeger: {response.status} {await response.text()}"
+                        )
+
+                    traces_data = await response.json()
+                    processed_traces = []
+                    for trace_data in traces_data.get("data", []):
+                        trace_steps = []
+                        for span in trace_data.get("spans", []):
+                            step_info = {
+                                "step": span.get("operationName"),
+                                "start_time": span.get("startTime"),
+                                "duration": span.get("duration"),
+                            }
+
+                            tags = span.get("tags", [])
+                            for tag in tags:
+                                if tag.get("key") == "input":
+                                    step_info["input"] = tag.get("value")
+                                elif tag.get("key") == "output":
+                                    step_info["output"] = tag.get("value")
+
+                            # we only want to return steps which have input and output
+                            if step_info.get("input") and step_info.get("output"):
+                                trace_steps.append(step_info)
+
+                        processed_traces.append(
+                            {
+                                "trace_id": trace_data.get("traceID"),
+                                "steps": trace_steps,
+                            }
+                        )
+
+                    return processed_traces
+
+        except Exception as e:
+            raise Exception(f"Error querying Jaeger traces: {str(e)}") from e