diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index 864380e9f..8119c3c4d 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -143,21 +143,55 @@ class EvalTrace(BaseModel):
     step: str
     input: str
     output: str
+    expected_output: str
+
+
+# One span together with the spans nested under it; `children` defaults to [].
+@json_schema_type
+class SpanNode(BaseModel):
+    span: Span
+    children: List["SpanNode"] = Field(default_factory=list)
+    status: Optional[SpanStatus] = None
+
+
+# A trace paired with the root of its span tree; `root` may be None.
+@json_schema_type
+class TraceTree(BaseModel):
+    trace: Trace
+    root: Optional[SpanNode] = None
+
+
+# Structural interface that trace lookup backends implement.
+class TraceStore(Protocol):
+    async def get_trace(
+        self,
+        trace_id: str,
+    ) -> TraceTree: ...
+
+    async def get_traces_for_sessions(
+        self,
+        session_ids: List[str],
+    ) -> List[Trace]: ...
 
 
 @runtime_checkable
 class Telemetry(Protocol):
+
     @webmethod(route="/telemetry/log-event")
     async def log_event(self, event: Event) -> None: ...
 
     @webmethod(route="/telemetry/get-trace", method="POST")
-    async def get_trace(self, trace_id: str) -> Trace: ...
+    async def get_trace(self, trace_id: str) -> TraceTree: ...
 
-    @webmethod(route="/telemetry/get-traces-for-agent-eval", method="POST")
-    async def get_traces_for_agent_eval(
+    @webmethod(route="/telemetry/get-agent-trace", method="POST")
+    async def get_agent_trace(
         self,
         session_ids: List[str],
-        lookback: str = "1h",
-        limit: int = 100,
-        dataset_id: Optional[str] = None,
     ) -> List[EvalTrace]: ...
+
+    @webmethod(route="/telemetry/export-agent-trace", method="POST")
+    async def export_agent_trace(
+        self,
+        session_ids: List[str],
+        dataset_id: Optional[str] = None,
+    ) -> None: ...
diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py
index e6dbdda64..61f75e911 100644
--- a/llama_stack/providers/inline/meta_reference/telemetry/console.py
+++ b/llama_stack/providers/inline/meta_reference/telemetry/console.py
@@ -7,8 +7,6 @@
 import json
 from typing import List, Optional
 
-from llama_stack.apis.telemetry.telemetry import Trace
-
 from .config import LogFormat
 
 from llama_stack.apis.telemetry import *  # noqa: F403
@@ -51,13 +49,26 @@ class ConsoleTelemetryImpl(Telemetry):
         if formatted:
             print(formatted)
 
-    async def get_trace(self, trace_id: str) -> Trace:
-        raise NotImplementedError()
+    # None of the trace query/export methods below are supported by this provider.
+    async def get_trace(self, trace_id: str) -> TraceTree:
+        raise NotImplementedError("Console telemetry does not support trace retrieval")
 
-    async def get_traces_for_agent_eval(
-        self, session_ids: List[str], lookback: str = "1h", limit: int = 100
+    async def get_agent_trace(
+        self,
+        session_ids: List[str],
     ) -> List[EvalTrace]:
-        raise NotImplementedError()
+        raise NotImplementedError(
+            "Console telemetry does not support agent trace retrieval"
+        )
+
+    async def export_agent_trace(
+        self,
+        session_ids: List[str],
+        dataset_id: Optional[str] = None,
+    ) -> None:
+        raise NotImplementedError(
+            "Console telemetry does not support agent trace export"
+        )
 
 
 COLORS = {
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py
index 56050411d..04594c040 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py
@@ -4,12 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
 from .config import OpenTelemetryConfig
 
 
 async def get_adapter_impl(config: OpenTelemetryConfig, deps):
     from .opentelemetry import OpenTelemetryAdapter
 
-    impl = OpenTelemetryAdapter(config, deps)
+    trace_store = JaegerTraceStore(config.jaeger_query_endpoint, config.service_name)
+    impl = OpenTelemetryAdapter(config, trace_store, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
index 9c4f8546f..81c1aed4f 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
@@ -18,7 +18,7 @@ class OpenTelemetryConfig(BaseModel):
         default="llama-stack",
         description="The service name to use for telemetry",
     )
-    export_endpoint: str = Field(
+    jaeger_query_endpoint: str = Field(
         default="http://localhost:16686/api/traces",
         description="The Jaeger query endpoint URL",
     )
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
index d944e0b4f..9a5d1b8c2 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
@@ -5,9 +5,7 @@
 # the root directory of this source tree.
import threading -from typing import Any, Dict, List, Optional - -import aiohttp +from typing import List from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -51,9 +49,12 @@ def is_tracing_enabled(tracer): class OpenTelemetryAdapter(Telemetry): - def __init__(self, config: OpenTelemetryConfig, deps) -> None: + def __init__( + self, config: OpenTelemetryConfig, trace_store: TraceStore, deps + ) -> None: self.config = config self.datasetio = deps[Api.datasetio] + self.trace_store = trace_store resource = Resource.create( { @@ -202,157 +203,67 @@ class OpenTelemetryAdapter(Telemetry): span.end() _GLOBAL_STORAGE["active_spans"].pop(span_id, None) - async def get_traces_for_agent_eval( + async def get_trace(self, trace_id: str) -> TraceTree: + return await self.trace_store.get_trace(trace_id) + + async def get_agent_trace( self, session_ids: List[str], - lookback: str = "1h", - limit: int = 100, - dataset_id: Optional[str] = None, ) -> List[EvalTrace]: traces = [] - - # Fetch traces for each session ID individually for session_id in session_ids: - params = { - "service": self.config.service_name, - "lookback": lookback, - "limit": limit, - "tags": f'{{"session_id":"{session_id}"}}', - } + traces_for_session = await self.trace_store.get_traces_for_sessions( + [session_id] + ) + for session_trace in traces_for_session: + trace_details = await self._get_simplified_agent_trace( + session_trace.trace_id, session_id + ) + traces.extend(trace_details) - try: - async with aiohttp.ClientSession() as session: - async with session.get( - self.config.export_endpoint, params=params - ) as response: - if response.status != 200: - raise Exception( - f"Failed to query Jaeger: {response.status} {await response.text()}" - ) - - traces_data = await response.json() - seen_trace_ids = set() - - for trace_data in traces_data.get("data", []): - trace_id = trace_data.get("traceID") - if trace_id and trace_id not in 
seen_trace_ids: - seen_trace_ids.add(trace_id) - trace_details = await self.get_trace_for_eval( - trace_id, session_id - ) - traces.extend(trace_details) - - except Exception as e: - raise Exception(f"Error querying Jaeger traces: {str(e)}") from e - - if dataset_id: - traces_dict = [ - { - "step": trace.step, - "input": trace.input, - "output": trace.output, - "session_id": trace.session_id, - } - for trace in traces - ] - await self.datasetio.upload_rows(dataset_id, traces_dict) return traces - async def get_trace(self, trace_id: str) -> Dict[str, Any]: - params = { - "traceID": trace_id, - } + async def export_agent_trace( + self, session_ids: List[str], dataset_id: str = None + ) -> None: + traces = await self.get_agent_trace(session_ids) + traces_dict = [ + { + "step": trace.step, + "input": trace.input, + "output": trace.output, + "session_id": trace.session_id, + } + for trace in traces + ] + await self.datasetio.upload_rows(dataset_id, traces_dict) - try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"{self.config.export_endpoint}/{trace_id}", params=params - ) as response: - if response.status != 200: - raise Exception( - f"Failed to query Jaeger: {response.status} {await response.text()}" - ) - - trace_data = await response.json() - if not trace_data.get("data") or not trace_data["data"]: - return None - - # First pass: Build span map - span_map = {} - for span in trace_data["data"][0]["spans"]: - start_time = span["startTime"] - end_time = start_time + span.get( - "duration", 0 - ) # Get end time from duration if available - - # Some systems store end time directly in the span - if "endTime" in span: - end_time = span["endTime"] - duration = end_time - start_time - else: - duration = span.get("duration", 0) - - span_map[span["spanID"]] = { - "id": span["spanID"], - "name": span["operationName"], - "start_time": start_time, - "end_time": end_time, - "duration": duration, - "tags": { - tag["key"]: tag["value"] for tag in 
span.get("tags", []) - }, - "children": [], - } - - # Second pass: Build parent-child relationships - root_spans = [] - for span in trace_data["data"][0]["spans"]: - references = span.get("references", []) - if references and references[0]["refType"] == "CHILD_OF": - parent_id = references[0]["spanID"] - if parent_id in span_map: - span_map[parent_id]["children"].append( - span_map[span["spanID"]] - ) - else: - root_spans.append(span_map[span["spanID"]]) - - return { - "trace_id": trace_id, - "spans": root_spans, - } - - except Exception as e: - raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e - - async def get_trace_for_eval( + async def _get_simplified_agent_trace( self, trace_id: str, session_id: str ) -> List[EvalTrace]: - """ - Get simplified trace information focusing on first-level children of create_and_execute_turn operations. - Returns a list of spans with name, input, and output information, sorted by start time. - """ - trace_data = await self.get_trace(trace_id) - if not trace_data: + trace_tree = await self.get_trace(trace_id) + if not trace_tree or not trace_tree.root: return [] - def find_execute_turn_children(spans: List[Dict[str, Any]]) -> List[EvalTrace]: - results: List[EvalTrace] = [] - for span in spans: - if span["name"] == "create_and_execute_turn": - # Extract and format children spans - children = sorted(span["children"], key=lambda x: x["start_time"]) - for child in children: - results.append( - EvalTrace( - step=child["name"], - input=child["tags"].get("input", ""), - output=child["tags"].get("output", ""), - session_id=session_id, - ) + def find_execute_turn_children(node: SpanNode) -> List[EvalTrace]: + results = [] + if node.span.name == "create_and_execute_turn": + # Sort children by start time + sorted_children = sorted(node.children, key=lambda x: x.span.start_time) + for child in sorted_children: + results.append( + EvalTrace( + step=child.span.name, + input=child.span.attributes.get("input", ""), + 
output=child.span.attributes.get("output", ""), + session_id=session_id, + expected_output="", ) - # Recursively search in children - results.extend(find_execute_turn_children(span["children"])) + ) + + # Recursively process children + for child in node.children: + results.extend(find_execute_turn_children(child)) return results - return find_execute_turn_children(trace_data["spans"]) + return find_execute_turn_children(trace_tree.root) diff --git a/llama_stack/providers/utils/telemetry/jaeger.py b/llama_stack/providers/utils/telemetry/jaeger.py new file mode 100644 index 000000000..3c28748ed --- /dev/null +++ b/llama_stack/providers/utils/telemetry/jaeger.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime, timedelta +from typing import List + +import aiohttp + +from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree + + +class JaegerTraceStore(TraceStore): + def __init__(self, endpoint: str, service_name: str): + self.endpoint = endpoint + self.service_name = service_name + + async def get_trace(self, trace_id: str) -> TraceTree: + params = { + "traceID": trace_id, + } + + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.endpoint}/{trace_id}", params=params + ) as response: + if response.status != 200: + raise Exception( + f"Failed to query Jaeger: {response.status} {await response.text()}" + ) + + trace_data = await response.json() + if not trace_data.get("data") or not trace_data["data"]: + return None + + # First pass: Build span map + span_map = {} + for jaeger_span in trace_data["data"][0]["spans"]: + start_time = datetime.fromtimestamp( + jaeger_span["startTime"] / 1000000 + ) + + # Some systems store end time directly in the span + if "endTime" in jaeger_span: + end_time = 
datetime.fromtimestamp( + jaeger_span["endTime"] / 1000000 + ) + else: + duration_microseconds = jaeger_span.get("duration", 0) + duration_timedelta = timedelta( + microseconds=duration_microseconds + ) + end_time = start_time + duration_timedelta + + span = Span( + span_id=jaeger_span["spanID"], + trace_id=trace_id, + name=jaeger_span["operationName"], + start_time=start_time, + end_time=end_time, + parent_span_id=next( + ( + ref["spanID"] + for ref in jaeger_span.get("references", []) + if ref["refType"] == "CHILD_OF" + ), + None, + ), + attributes={ + tag["key"]: tag["value"] + for tag in jaeger_span.get("tags", []) + }, + ) + + span_map[span.span_id] = SpanNode(span=span) + + # Second pass: Build parent-child relationships + root_node = None + for span_node in span_map.values(): + parent_id = span_node.span.parent_span_id + if parent_id and parent_id in span_map: + span_map[parent_id].children.append(span_node) + elif not parent_id: + root_node = span_node + + trace = Trace( + trace_id=trace_id, + root_span_id=root_node.span.span_id if root_node else "", + start_time=( + root_node.span.start_time if root_node else datetime.now() + ), + end_time=root_node.span.end_time if root_node else None, + ) + + return TraceTree(trace=trace, root=root_node) + + except Exception as e: + raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e + + async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]: + traces = [] + + # Fetch traces for each session ID individually + for session_id in session_ids: + params = { + "service": self.service_name, + "tags": f'{{"session_id":"{session_id}"}}', + "limit": 100, + "lookback": "10000h", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.get(self.endpoint, params=params) as response: + if response.status != 200: + raise Exception( + f"Failed to query Jaeger: {response.status} {await response.text()}" + ) + + traces_data = await response.json() + seen_trace_ids = 
set() + + for trace_data in traces_data.get("data", []): + trace_id = trace_data.get("traceID") + if trace_id and trace_id not in seen_trace_ids: + seen_trace_ids.add(trace_id) + traces.append( + Trace( + trace_id=trace_id, + root_span_id="", + start_time=datetime.now(), + ) + ) + + except Exception as e: + raise Exception(f"Error querying Jaeger traces: {str(e)}") from e + + return traces