diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index 7e8d6bdd3..c9d939d2d 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -142,10 +142,14 @@ class Telemetry(Protocol):
     @webmethod(route="/telemetry/log-event")
     async def log_event(self, event: Event) -> None: ...
 
-    @webmethod(route="/telemetry/get-trace", method="GET")
+    @webmethod(route="/telemetry/get-trace", method="POST")
     async def get_trace(self, trace_id: str) -> Trace: ...
 
-    @webmethod(route="/telemetry/get-traces-for-session", method="POST")
-    async def get_traces_for_session(
-        self, session_id: str, lookback: str = "1h", limit: int = 100
+    @webmethod(route="/telemetry/get-traces-for-eval", method="POST")
+    async def get_traces_for_eval(
+        self,
+        session_ids: List[str],
+        lookback: str = "1h",
+        limit: int = 100,
+        dataset_id: Optional[str] = None,
     ) -> List[Trace]: ...
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
index d69d3d0d8..c4e4afabe 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
+import json
 import threading
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import aiohttp
 
@@ -173,7 +174,6 @@ class OpenTelemetryAdapter(Telemetry):
                 parent_span_id = string_to_span_id(event.payload.parent_span_id)
                 parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
 
-            # Create a new trace context with the trace_id
             context = trace.Context(trace_id=trace_id)
             if parent_span:
                 context = trace.set_span_in_context(parent_span, context)
@@ -182,14 +182,9 @@ class OpenTelemetryAdapter(Telemetry):
                 name=event.payload.name,
                 context=context,
                 attributes=event.attributes or {},
-                start_time=int(event.timestamp.timestamp() * 1e9),
             )
             _GLOBAL_STORAGE["active_spans"][span_id] = span
 
-            # Set as current span using context manager
-            with trace.use_span(span, end_on_exit=False):
-                pass  # Let the span continue beyond this block
-
         elif isinstance(event.payload, SpanEndPayload):
             span = _GLOBAL_STORAGE["active_spans"].get(span_id)
             if span:
@@ -202,64 +197,185 @@ class OpenTelemetryAdapter(Telemetry):
                     else trace.Status(status_code=trace.StatusCode.ERROR)
                 )
                 span.set_status(status)
-                span.end(end_time=int(event.timestamp.timestamp() * 1e9))
-
-                # Remove from active spans
+                span.end()
                 _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
 
-    async def get_trace(self, trace_id: str) -> Trace:
-        raise NotImplementedError("Trace retrieval not implemented yet")
-
-    async def get_traces_for_session(
-        self, session_id: str, lookback: str = "1h", limit: int = 100
-    ) -> List[Dict[str, Any]]:
-        params = {
-            "tags": f'{{"session_id":"{session_id}"}}',
-            "lookback": lookback,
-            "limit": limit,
-            "service": self.config.service_name,
-        }
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.get(
-                    self.config.export_endpoint, params=params
-                ) as response:
-                    if response.status != 200:
-                        raise Exception(
-                            f"Failed to query Jaeger: {response.status} {await response.text()}"
-                        )
-
-                    traces_data = await response.json()
-                    processed_traces = []
-
-                    for trace_data in traces_data.get("data", []):
-                        trace_steps = []
-                        for span in trace_data.get("spans", []):
-                            step_info = {
-                                "step": span.get("operationName"),
-                                "start_time": span.get("startTime"),
-                                "duration": span.get("duration"),
-                            }
-
-                            tags = span.get("tags", [])
-                            for tag in tags:
-                                if tag.get("key") == "input":
-                                    step_info["input"] = tag.get("value")
-                                elif tag.get("key") == "output":
-                                    step_info["output"] = tag.get("value")
-                            # we only want to return steps which have input and output
-                            if step_info.get("input") and step_info.get("output"):
-                                trace_steps.append(step_info)
-
-                        processed_traces.append(
-                            {
-                                "trace_id": trace_data.get("traceID"),
-                                "steps": trace_steps,
-                            }
-                        )
-
-                    return processed_traces
-
-        except Exception as e:
-            raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
+    async def get_traces_for_eval(
+        self,
+        session_ids: List[str],
+        lookback: str = "1h",
+        limit: int = 100,
+        dataset_id: Optional[str] = None,
+    ) -> List[List[Dict[str, Any]]]:
+        """Return the eval steps of every trace recorded for the sessions.
+
+        Jaeger is searched once per session ID, so ``lookback`` and ``limit``
+        apply to each session separately. Every distinct trace found is
+        reduced by ``get_trace_for_eval`` to its list of steps; traces that
+        yield no steps are left out of the result.
+
+        ``dataset_id`` is accepted for interface compatibility and is not
+        used by this implementation.
+
+        Raises:
+            Exception: if a Jaeger query fails or returns a non-200 status.
+        """
+        traces = []
+        # Shared across sessions so no trace is fetched or returned twice.
+        seen_trace_ids = set()
+
+        try:
+            # One HTTP session serves every search instead of one per ID.
+            async with aiohttp.ClientSession() as session:
+                for session_id in session_ids:
+                    params = {
+                        "service": self.config.service_name,
+                        "lookback": lookback,
+                        "limit": limit,
+                        # json.dumps escapes quotes and backslashes in the ID,
+                        # so the tag filter is always valid JSON.
+                        "tags": json.dumps(
+                            {"session_id": session_id}, separators=(",", ":")
+                        ),
+                    }
+                    async with session.get(
+                        self.config.export_endpoint, params=params
+                    ) as response:
+                        if response.status != 200:
+                            raise Exception(
+                                f"Failed to query Jaeger: {response.status} {await response.text()}"
+                            )
+                        traces_data = await response.json()
+
+                    # Tolerate a null or missing "data" field.
+                    for trace_data in traces_data.get("data") or []:
+                        trace_id = trace_data.get("traceID")
+                        if not trace_id or trace_id in seen_trace_ids:
+                            continue
+                        seen_trace_ids.add(trace_id)
+                        trace_details = await self.get_trace_for_eval(trace_id)
+                        if trace_details:
+                            traces.append(trace_details)
+        except Exception as e:
+            raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
+
+        return traces
+
+    async def get_trace(self, trace_id: str) -> Optional[Dict[str, Any]]:
+        """Fetch one trace from Jaeger and return its spans as a tree.
+
+        Returns:
+            ``{"trace_id": ..., "spans": [...]}`` where ``spans`` holds the
+            root spans and every span dict has ``id``, ``name``,
+            ``start_time``, ``end_time``, ``duration``, ``tags`` (key to
+            value) and ``children`` (nested span dicts). ``None`` when the
+            response carries no trace data.
+
+        Raises:
+            Exception: if the query fails, returns a non-200 status, or the
+                payload cannot be parsed.
+        """
+        params = {
+            "traceID": trace_id,
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(
+                    f"{self.config.export_endpoint}/{trace_id}", params=params
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Failed to query Jaeger: {response.status} {await response.text()}"
+                        )
+                    trace_data = await response.json()
+
+            if not trace_data.get("data"):
+                return None
+            spans = trace_data["data"][0]["spans"]
+
+            # First pass: index every span by ID so parents can be looked up.
+            span_map = {}
+            for span in spans:
+                start_time = span["startTime"]
+                # Prefer an explicit end time; otherwise derive it from the
+                # duration.
+                if "endTime" in span:
+                    end_time = span["endTime"]
+                    duration = end_time - start_time
+                else:
+                    duration = span.get("duration", 0)
+                    end_time = start_time + duration
+
+                span_map[span["spanID"]] = {
+                    "id": span["spanID"],
+                    "name": span["operationName"],
+                    "start_time": start_time,
+                    "end_time": end_time,
+                    "duration": duration,
+                    "tags": {
+                        tag["key"]: tag["value"] for tag in (span.get("tags") or [])
+                    },
+                    "children": [],
+                }
+
+            # Second pass: attach each span to its parent.
+            root_spans = []
+            for span in spans:
+                node = span_map[span["spanID"]]
+                parent_id = next(
+                    (
+                        ref["spanID"]
+                        for ref in (span.get("references") or [])
+                        if ref.get("refType") == "CHILD_OF"
+                    ),
+                    None,
+                )
+                if parent_id in span_map:
+                    span_map[parent_id]["children"].append(node)
+                else:
+                    # No parent, or the parent is absent from this payload:
+                    # keep the span as a root instead of silently dropping it.
+                    root_spans.append(node)
+
+            return {
+                "trace_id": trace_id,
+                "spans": root_spans,
+            }
+
+        except Exception as e:
+            raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
+
+    async def get_trace_for_eval(self, trace_id: str) -> List[Dict[str, Any]]:
+        """Return the steps of a trace in the form used for evals.
+
+        A step is a direct child of a ``create_and_execute_turn`` span, found
+        at any depth of the span tree. Each step is reported as a dict with its
+        ``name`` plus the ``input`` and ``output`` tags (empty string when a
+        tag is absent). The children of one turn are ordered by start time.
+        Returns an empty list when the trace has no data.
+        """
+        trace_data = await self.get_trace(trace_id)
+        if not trace_data:
+            return []
+
+        def find_execute_turn_children(
+            spans: List[Dict[str, Any]]
+        ) -> List[Dict[str, Any]]:
+            results = []
+            for span in spans:
+                if span["name"] == "create_and_execute_turn":
+                    children = sorted(span["children"], key=lambda x: x["start_time"])
+                    results.extend(
+                        {
+                            "name": child["name"],
+                            "input": child["tags"].get("input", ""),
+                            "output": child["tags"].get("output", ""),
+                        }
+                        for child in children
+                    )
+                # Descend regardless, so turn spans deeper in the tree are found.
+                results.extend(find_execute_turn_children(span["children"]))
+            return results
+
+        return find_execute_turn_children(trace_data["spans"])