Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
endpoint to bulk export traces for eval

This commit is contained in:
  parent b3e149334a
  commit dfe152cb97

2 changed files with 133 additions and 51 deletions
@@ -142,10 +142,14 @@ class Telemetry(Protocol):
     @webmethod(route="/telemetry/log-event")
     async def log_event(self, event: Event) -> None: ...

-    @webmethod(route="/telemetry/get-trace", method="GET")
+    @webmethod(route="/telemetry/get-trace", method="POST")
     async def get_trace(self, trace_id: str) -> Trace: ...

-    @webmethod(route="/telemetry/get-traces-for-session", method="POST")
-    async def get_traces_for_session(
-        self, session_id: str, lookback: str = "1h", limit: int = 100
+    @webmethod(route="/telemetry/get-traces-for-eval", method="POST")
+    async def get_traces_for_eval(
+        self,
+        session_ids: List[str],
+        lookback: str = "1h",
+        limit: int = 100,
+        dataset_id: Optional[str] = None,
     ) -> List[Trace]: ...
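For context, a rough client-side sketch of how the new bulk-export endpoint might be called. This is illustrative only: the base URL, port, and JSON body shape are assumptions, and the actual request encoding depends on how the @webmethod routes are served by the stack server.

```python
import asyncio

import aiohttp


async def fetch_traces_for_eval(base_url: str, session_ids: list) -> list:
    # Hypothetical request body mirroring the new get_traces_for_eval signature.
    payload = {
        "session_ids": session_ids,
        "lookback": "1h",
        "limit": 100,
        "dataset_id": None,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{base_url}/telemetry/get-traces-for-eval", json=payload
        ) as response:
            response.raise_for_status()
            return await response.json()


# Example usage (assumed local server address):
# asyncio.run(fetch_traces_for_eval("http://localhost:5000", ["session-123"]))
```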
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import threading
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import aiohttp

@@ -173,7 +173,6 @@ class OpenTelemetryAdapter(Telemetry):
            parent_span_id = string_to_span_id(event.payload.parent_span_id)
            parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)

-            # Create a new trace context with the trace_id
            context = trace.Context(trace_id=trace_id)
            if parent_span:
                context = trace.set_span_in_context(parent_span, context)
@@ -182,14 +181,9 @@ class OpenTelemetryAdapter(Telemetry):
                name=event.payload.name,
                context=context,
                attributes=event.attributes or {},
-                start_time=int(event.timestamp.timestamp() * 1e9),
            )
            _GLOBAL_STORAGE["active_spans"][span_id] = span

-            # Set as current span using context manager
-            with trace.use_span(span, end_on_exit=False):
-                pass  # Let the span continue beyond this block
-
        elif isinstance(event.payload, SpanEndPayload):
            span = _GLOBAL_STORAGE["active_spans"].get(span_id)
            if span:
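With the explicit start_time and the trace.use_span activation removed, the adapter simply keeps the span object in _GLOBAL_STORAGE and ends it by hand when the matching SpanEndPayload arrives. A minimal standalone sketch of that manual start/end pattern is shown below; the tracer setup, span name, and ids are illustrative, not taken from the adapter.

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

active_spans = {}

# On a span-start event: create the span (not activated as "current") and
# remember it by its id.
span = tracer.start_span("inference", attributes={"example": True})
active_spans["span-1"] = span

# ... later, on the matching span-end event: close it explicitly.
active_spans.pop("span-1").end()
```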
@@ -202,22 +196,25 @@ class OpenTelemetryAdapter(Telemetry):
                    else trace.Status(status_code=trace.StatusCode.ERROR)
                )
                span.set_status(status)
-                span.end(end_time=int(event.timestamp.timestamp() * 1e9))
+                span.end()

-                # Remove from active spans
                _GLOBAL_STORAGE["active_spans"].pop(span_id, None)

-    async def get_trace(self, trace_id: str) -> Trace:
-        raise NotImplementedError("Trace retrieval not implemented yet")
-
-    async def get_traces_for_session(
-        self, session_id: str, lookback: str = "1h", limit: int = 100
+    async def get_traces_for_eval(
+        self,
+        session_ids: List[str],
+        lookback: str = "1h",
+        limit: int = 100,
+        dataset_id: Optional[str] = None,
     ) -> List[Dict[str, Any]]:
-        params = {
-            "tags": f'{{"session_id":"{session_id}"}}',
-            "lookback": lookback,
-            "limit": limit,
-            "service": self.config.service_name,
-        }
-
-        try:
+        traces = []
+
+        # Fetch traces for each session ID individually
+        for session_id in session_ids:
+            params = {
+                "service": self.config.service_name,
+                "lookback": lookback,
+                "limit": limit,
+                "tags": f'{{"session_id":"{session_id}"}}',
+            }
+
+            try:
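The params built here drive a Jaeger trace search scoped to one session_id tag per request. As a rough illustration, an equivalent direct call against Jaeger's HTTP query API might look like the sketch below; the /api/traces path and base URL are assumptions about a standard Jaeger deployment, while the adapter itself derives its endpoint from self.config (not shown in this hunk).

```python
import aiohttp


async def search_session_traces(jaeger_url: str, service: str, session_id: str) -> list:
    # Same query parameters as the adapter builds, expressed as a plain
    # Jaeger search request (one call per session id).
    params = {
        "service": service,
        "lookback": "1h",
        "limit": "100",
        "tags": f'{{"session_id":"{session_id}"}}',
    }
    async with aiohttp.ClientSession() as http:
        async with http.get(f"{jaeger_url}/api/traces", params=params) as resp:
            resp.raise_for_status()
            return (await resp.json()).get("data", [])
```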
@@ -231,35 +228,116 @@ class OpenTelemetryAdapter(Telemetry):
                            )

-                    traces_data = await response.json()
-                    processed_traces = []
-
-                    for trace_data in traces_data.get("data", []):
-                        trace_steps = []
-                        for span in trace_data.get("spans", []):
-                            step_info = {
-                                "step": span.get("operationName"),
-                                "start_time": span.get("startTime"),
-                                "duration": span.get("duration"),
-                            }
-
-                            tags = span.get("tags", [])
-                            for tag in tags:
-                                if tag.get("key") == "input":
-                                    step_info["input"] = tag.get("value")
-                                elif tag.get("key") == "output":
-                                    step_info["output"] = tag.get("value")
-                            # we only want to return steps which have input and output
-                            if step_info.get("input") and step_info.get("output"):
-                                trace_steps.append(step_info)
-
-                        processed_traces.append(
-                            {
-                                "trace_id": trace_data.get("traceID"),
-                                "steps": trace_steps,
-                            }
-                        )
-
-                    return processed_traces
-
-        except Exception as e:
-            raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
+                        traces_data = await response.json()
+                        seen_trace_ids = set()
+
+                        # For each trace ID, get the detailed trace information
+                        for trace_data in traces_data.get("data", []):
+                            trace_id = trace_data.get("traceID")
+                            if trace_id and trace_id not in seen_trace_ids:
+                                seen_trace_ids.add(trace_id)
+                                trace_details = await self.get_trace_for_eval(trace_id)
+                                if trace_details:
+                                    traces.append(trace_details)
+            except Exception as e:
+                raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
+
+        return traces
+
+    async def get_trace(self, trace_id: str) -> Dict[str, Any]:
+        params = {
+            "traceID": trace_id,
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(
+                    f"{self.config.export_endpoint}/{trace_id}", params=params
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Failed to query Jaeger: {response.status} {await response.text()}"
+                        )
+
+                    trace_data = await response.json()
+                    if not trace_data.get("data") or not trace_data["data"]:
+                        return None
+
+                    # First pass: Build span map
+                    span_map = {}
+                    for span in trace_data["data"][0]["spans"]:
+                        start_time = span["startTime"]
+                        end_time = start_time + span.get(
+                            "duration", 0
+                        )  # Get end time from duration if available
+
+                        # Some systems store end time directly in the span
+                        if "endTime" in span:
+                            end_time = span["endTime"]
+                            duration = end_time - start_time
+                        else:
+                            duration = span.get("duration", 0)
+
+                        span_map[span["spanID"]] = {
+                            "id": span["spanID"],
+                            "name": span["operationName"],
+                            "start_time": start_time,
+                            "end_time": end_time,
+                            "duration": duration,
+                            "tags": {
+                                tag["key"]: tag["value"] for tag in span.get("tags", [])
+                            },
+                            "children": [],
+                        }
+
+                    # Second pass: Build parent-child relationships
+                    root_spans = []
+                    for span in trace_data["data"][0]["spans"]:
+                        references = span.get("references", [])
+                        if references and references[0]["refType"] == "CHILD_OF":
+                            parent_id = references[0]["spanID"]
+                            if parent_id in span_map:
+                                span_map[parent_id]["children"].append(
+                                    span_map[span["spanID"]]
+                                )
+                        else:
+                            root_spans.append(span_map[span["spanID"]])
+
+                    return {
+                        "trace_id": trace_id,
+                        "spans": root_spans,
+                    }
+
+        except Exception as e:
+            raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
+
+    async def get_trace_for_eval(self, trace_id: str) -> List[Dict[str, Any]]:
+        """
+        Get simplified trace information focusing on first-level children of create_and_execute_turn operations.
+        Returns a list of spans with name, input, and output information, sorted by start time.
+        """
+        trace_data = await self.get_trace(trace_id)
+        if not trace_data:
+            return []
+
+        def find_execute_turn_children(
+            spans: List[Dict[str, Any]]
+        ) -> List[Dict[str, Any]]:
+            results = []
+            for span in spans:
+                if span["name"] == "create_and_execute_turn":
+                    # Extract and format children spans
+                    children = sorted(span["children"], key=lambda x: x["start_time"])
+                    for child in children:
+                        results.append(
+                            {
+                                "name": child["name"],
+                                "input": child["tags"].get("input", ""),
+                                "output": child["tags"].get("output", ""),
+                            }
+                        )
+                # Recursively search in children
+                results.extend(find_execute_turn_children(span["children"]))
+            return results
+
+        return find_execute_turn_children(trace_data["spans"])
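For each trace, the bulk export ends up as a list of simplified {name, input, output} step records produced by get_trace_for_eval. A small illustrative sketch of how a caller might flatten that structure into eval rows; the helper name and the row schema are assumptions, not part of this commit.

```python
from typing import Any, Dict, List


def traces_to_eval_rows(traces: List[List[Dict[str, Any]]]) -> List[Dict[str, str]]:
    """Flatten bulk-exported traces into (input, output) rows.

    `traces` is assumed to be the value returned by get_traces_for_eval:
    one list of simplified step dicts per exported trace.
    """
    rows = []
    for steps in traces:
        for step in steps:
            if step.get("input") and step.get("output"):
                rows.append(
                    {
                        "step": step["name"],
                        "input": step["input"],
                        "output": step["output"],
                    }
                )
    return rows


# Example with a toy trace shaped like get_trace_for_eval's output:
example = [[{"name": "inference", "input": "hi", "output": "hello"}]]
print(traces_to_eval_rows(example))
```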