diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index 320af3e69..7e8d6bdd3 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -6,7 +6,16 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
@@ -135,3 +144,8 @@ class Telemetry(Protocol):
 
     @webmethod(route="/telemetry/get-trace", method="GET")
     async def get_trace(self, trace_id: str) -> Trace: ...
+
+    @webmethod(route="/telemetry/get-traces-for-session", method="POST")
+    async def get_traces_for_session(
+        self, session_id: str, lookback: str = "1h", limit: int = 100
+    ) -> List[Trace]: ...
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 7e2979c24..dddb34b5d 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -285,8 +285,9 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> AsyncGenerator:
         with tracing.span("run_shields") as span:
             span.set_attribute("turn_id", turn_id)
-            span.set_attribute("messages", [m.model_dump_json() for m in messages])
+            span.set_attribute("input", [m.model_dump_json() for m in messages])
             if len(shields) == 0:
+                span.set_attribute("output", "no shields")
                 return
 
             step_id = str(uuid.uuid4())
@@ -315,6 +316,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
+                span.set_attribute("output", e.violation.model_dump_json())
 
                 yield CompletionMessage(
                     content=str(e),
@@ -334,6 +336,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
+            span.set_attribute("output", "no violations")
 
     async def _run(
         self,
@@ -365,7 +368,10 @@ class ChatAgent(ShieldRunnerMixin):
                 rag_context, bank_ids = await self._retrieve_context(
                     session_id, input_messages, attachments
                 )
-                span.set_attribute("rag_context", rag_context)
+                span.set_attribute(
+                    "input", [m.model_dump_json() for m in input_messages]
+                )
+                span.set_attribute("output", rag_context)
                 span.set_attribute("bank_ids", bank_ids)
 
         step_id = str(uuid.uuid4())
@@ -473,9 +479,11 @@ class ChatAgent(ShieldRunnerMixin):
                 if event.stop_reason is not None:
                     stop_reason = event.stop_reason
             span.set_attribute("stop_reason", stop_reason)
-            span.set_attribute("content", content)
             span.set_attribute(
-                "tool_calls", [tc.model_dump_json() for tc in tool_calls]
+                "input", [m.model_dump_json() for m in input_messages]
+            )
+            span.set_attribute(
+                "output", f"content: {content} tool_calls: {tool_calls}"
             )
 
         stop_reason = stop_reason or StopReason.out_of_tokens
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
index 5e9dff1a1..9c4f8546f 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
@@ -18,6 +18,10 @@ class OpenTelemetryConfig(BaseModel):
         default="llama-stack",
         description="The service name to use for telemetry",
     )
+    export_endpoint: str = Field(
+        default="http://localhost:16686/api/traces",
+        description="The Jaeger query endpoint URL",
+    )
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
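For reference, a minimal sketch of how the new export_endpoint config field might be exercised against a local Jaeger all-in-one instance (port 16686 is Jaeger's query service; the values below are illustrative, not part of this diff):

    # Illustrative sketch, not part of the change: builds the adapter config
    # pointing at a local Jaeger query API. Pydantic v2 is assumed, per the
    # model_dump_json usage elsewhere in this diff.
    from llama_stack.providers.remote.telemetry.opentelemetry.config import (
        OpenTelemetryConfig,
    )

    config = OpenTelemetryConfig(
        service_name="llama-stack",
        export_endpoint="http://localhost:16686/api/traces",
    )
    print(config.model_dump_json())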
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
index c9830fd9d..d69d3d0d8 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
@@ -5,6 +5,9 @@
 # the root directory of this source tree.
 
 import threading
+from typing import Any, Dict, List
+
+import aiohttp
 
 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -206,3 +209,57 @@ class OpenTelemetryAdapter(Telemetry):
 
     async def get_trace(self, trace_id: str) -> Trace:
         raise NotImplementedError("Trace retrieval not implemented yet")
+
+    async def get_traces_for_session(
+        self, session_id: str, lookback: str = "1h", limit: int = 100
+    ) -> List[Dict[str, Any]]:
+        params = {
+            "tags": f'{{"session_id":"{session_id}"}}',
+            "lookback": lookback,
+            "limit": limit,
+            "service": self.config.service_name,
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(
+                    self.config.export_endpoint, params=params
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Failed to query Jaeger: {response.status} {await response.text()}"
+                        )
+
+                    traces_data = await response.json()
+                    processed_traces = []
+
+                    for trace_data in traces_data.get("data", []):
+                        trace_steps = []
+                        for span in trace_data.get("spans", []):
+                            step_info = {
+                                "step": span.get("operationName"),
+                                "start_time": span.get("startTime"),
+                                "duration": span.get("duration"),
+                            }
+
+                            tags = span.get("tags", [])
+                            for tag in tags:
+                                if tag.get("key") == "input":
+                                    step_info["input"] = tag.get("value")
+                                elif tag.get("key") == "output":
+                                    step_info["output"] = tag.get("value")
+                            # we only want to return steps which have input and output
+                            if step_info.get("input") and step_info.get("output"):
+                                trace_steps.append(step_info)
+
+                        processed_traces.append(
+                            {
+                                "trace_id": trace_data.get("traceID"),
+                                "steps": trace_steps,
+                            }
+                        )
+
+                    return processed_traces
+
+        except Exception as e:
+            raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
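A hedged usage sketch of the new session query follows. Note that the Telemetry protocol declares the return type as List[Trace] while this adapter returns a List[Dict[str, Any]] of processed steps, so the sketch indexes the dict shape built above. The OpenTelemetryAdapter constructor signature is assumed (this diff does not show it), and the session id is a placeholder:

    # Illustrative sketch: consuming get_traces_for_session against a running
    # Jaeger instance that has already received spans tagged with session_id.
    import asyncio

    from llama_stack.providers.remote.telemetry.opentelemetry.config import (
        OpenTelemetryConfig,
    )
    from llama_stack.providers.remote.telemetry.opentelemetry.opentelemetry import (
        OpenTelemetryAdapter,
    )

    async def main() -> None:
        # OpenTelemetryAdapter(config) is an assumed constructor signature.
        adapter = OpenTelemetryAdapter(OpenTelemetryConfig())
        traces = await adapter.get_traces_for_session(
            session_id="example-session-id",  # placeholder
            lookback="2h",
            limit=50,
        )
        for t in traces:
            print(t["trace_id"])
            for step in t["steps"]:
                # only steps that carried both "input" and "output" span tags
                print(f"  {step['step']}: {step['input']!r} -> {step['output']!r}")

    asyncio.run(main())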