endpoint to bulk export traces for eval

This commit is contained in:
Dinesh Yeduguru 2024-11-26 22:09:49 -08:00
parent b3e149334a
commit dfe152cb97
2 changed files with 133 additions and 51 deletions

View file

@ -142,10 +142,14 @@ class Telemetry(Protocol):
@webmethod(route="/telemetry/log-event")
async def log_event(self, event: Event) -> None: ...
@webmethod(route="/telemetry/get-trace", method="GET")
@webmethod(route="/telemetry/get-trace", method="POST")
async def get_trace(self, trace_id: str) -> Trace: ...
@webmethod(route="/telemetry/get-traces-for-session", method="POST")
async def get_traces_for_session(
self, session_id: str, lookback: str = "1h", limit: int = 100
@webmethod(route="/telemetry/get-traces-for-eval", method="POST")
async def get_traces_for_eval(
self,
session_ids: List[str],
lookback: str = "1h",
limit: int = 100,
dataset_id: Optional[str] = None,
) -> List[Trace]: ...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import threading
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import aiohttp
@ -173,7 +173,6 @@ class OpenTelemetryAdapter(Telemetry):
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
# Create a new trace context with the trace_id
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
@ -182,14 +181,9 @@ class OpenTelemetryAdapter(Telemetry):
name=event.payload.name,
context=context,
attributes=event.attributes or {},
start_time=int(event.timestamp.timestamp() * 1e9),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
# Set as current span using context manager
with trace.use_span(span, end_on_exit=False):
pass # Let the span continue beyond this block
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
@ -202,64 +196,148 @@ class OpenTelemetryAdapter(Telemetry):
else trace.Status(status_code=trace.StatusCode.ERROR)
)
span.set_status(status)
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
# Remove from active spans
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError("Trace retrieval not implemented yet")
async def get_traces_for_session(
self, session_id: str, lookback: str = "1h", limit: int = 100
async def get_traces_for_eval(
self,
session_ids: List[str],
lookback: str = "1h",
limit: int = 100,
dataset_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
traces = []
# Fetch traces for each session ID individually
for session_id in session_ids:
params = {
"service": self.config.service_name,
"lookback": lookback,
"limit": limit,
"tags": f'{{"session_id":"{session_id}"}}',
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(
self.config.export_endpoint, params=params
) as response:
if response.status != 200:
raise Exception(
f"Failed to query Jaeger: {response.status} {await response.text()}"
)
traces_data = await response.json()
seen_trace_ids = set()
# For each trace ID, get the detailed trace information
for trace_data in traces_data.get("data", []):
trace_id = trace_data.get("traceID")
if trace_id and trace_id not in seen_trace_ids:
seen_trace_ids.add(trace_id)
trace_details = await self.get_trace_for_eval(trace_id)
if trace_details:
traces.append(trace_details)
except Exception as e:
raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
return traces
async def get_trace(self, trace_id: str) -> Dict[str, Any]:
params = {
"tags": f'{{"session_id":"{session_id}"}}',
"lookback": lookback,
"limit": limit,
"service": self.config.service_name,
"traceID": trace_id,
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(
self.config.export_endpoint, params=params
f"{self.config.export_endpoint}/{trace_id}", params=params
) as response:
if response.status != 200:
raise Exception(
f"Failed to query Jaeger: {response.status} {await response.text()}"
)
traces_data = await response.json()
processed_traces = []
trace_data = await response.json()
if not trace_data.get("data") or not trace_data["data"]:
return None
for trace_data in traces_data.get("data", []):
trace_steps = []
for span in trace_data.get("spans", []):
step_info = {
"step": span.get("operationName"),
"start_time": span.get("startTime"),
"duration": span.get("duration"),
}
# First pass: Build span map
span_map = {}
for span in trace_data["data"][0]["spans"]:
start_time = span["startTime"]
end_time = start_time + span.get(
"duration", 0
) # Get end time from duration if available
tags = span.get("tags", [])
for tag in tags:
if tag.get("key") == "input":
step_info["input"] = tag.get("value")
elif tag.get("key") == "output":
step_info["output"] = tag.get("value")
# we only want to return steps which have input and output
if step_info.get("input") and step_info.get("output"):
trace_steps.append(step_info)
# Some systems store end time directly in the span
if "endTime" in span:
end_time = span["endTime"]
duration = end_time - start_time
else:
duration = span.get("duration", 0)
processed_traces.append(
{
"trace_id": trace_data.get("traceID"),
"steps": trace_steps,
}
)
span_map[span["spanID"]] = {
"id": span["spanID"],
"name": span["operationName"],
"start_time": start_time,
"end_time": end_time,
"duration": duration,
"tags": {
tag["key"]: tag["value"] for tag in span.get("tags", [])
},
"children": [],
}
return processed_traces
# Second pass: Build parent-child relationships
root_spans = []
for span in trace_data["data"][0]["spans"]:
references = span.get("references", [])
if references and references[0]["refType"] == "CHILD_OF":
parent_id = references[0]["spanID"]
if parent_id in span_map:
span_map[parent_id]["children"].append(
span_map[span["spanID"]]
)
else:
root_spans.append(span_map[span["spanID"]])
return {
"trace_id": trace_id,
"spans": root_spans,
}
except Exception as e:
raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
async def get_trace_for_eval(self, trace_id: str) -> List[Dict[str, Any]]:
"""
Get simplified trace information focusing on first-level children of create_and_execute_turn operations.
Returns a list of spans with name, input, and output information, sorted by start time.
"""
trace_data = await self.get_trace(trace_id)
if not trace_data:
return []
def find_execute_turn_children(
spans: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
results = []
for span in spans:
if span["name"] == "create_and_execute_turn":
# Extract and format children spans
children = sorted(span["children"], key=lambda x: x["start_time"])
for child in children:
results.append(
{
"name": child["name"],
"input": child["tags"].get("input", ""),
"output": child["tags"].get("output", ""),
}
)
# Recursively search in children
results.extend(find_execute_turn_children(span["children"]))
return results
return find_execute_turn_children(trace_data["spans"])