Add trace tree types (SpanNode, TraceTree, TraceStore) and a working Jaeger-backed trace store

This commit is contained in:
Dinesh Yeduguru 2024-11-29 18:53:52 -08:00
parent b7c4997e91
commit 94b4113c63
6 changed files with 252 additions and 157 deletions

View file

@ -143,21 +143,52 @@ class EvalTrace(BaseModel):
step: str step: str
input: str input: str
output: str output: str
expected_output: str
@json_schema_type
class SpanNode(BaseModel):
    """One node of a span tree: a span plus the nodes nested beneath it."""

    # The span carried by this node.
    span: Span
    # Child nodes; each instance gets its own fresh empty list by default.
    children: List["SpanNode"] = Field(default_factory=list)
    # Optional status attached to the node; left as None when not provided.
    status: Optional[SpanStatus] = None
@json_schema_type
class TraceTree(BaseModel):
    """A trace together with the (optional) root node of its span tree."""

    # The trace record this tree belongs to.
    trace: Trace
    # Root of the span tree; None when no root node is available.
    root: Optional[SpanNode] = None
class TraceStore(Protocol):
    """Structural interface for a backend that traces can be read back from."""

    async def get_trace(
        self,
        trace_id: str,
    ) -> TraceTree:
        """Return the span tree for the trace identified by ``trace_id``."""
        ...

    async def get_traces_for_sessions(
        self,
        session_ids: List[str],
    ) -> List[Trace]:
        """Return the traces associated with the given session IDs."""
        # The return annotation was previously the list literal ``[Trace]``,
        # which is not a valid type; ``List[Trace]`` is what implementers return.
        ...
@runtime_checkable
class Telemetry(Protocol):
    """Telemetry API: event logging plus trace retrieval and export."""

    @webmethod(route="/telemetry/log-event")
    async def log_event(self, event: Event) -> None:
        """Record a single telemetry event."""
        ...

    @webmethod(route="/telemetry/get-trace", method="POST")
    async def get_trace(self, trace_id: str) -> TraceTree:
        """Return the span tree for the trace identified by ``trace_id``."""
        ...

    @webmethod(route="/telemetry/get-agent-trace", method="POST")
    async def get_agent_trace(
        self,
        session_ids: List[str],
    ) -> List[EvalTrace]:
        """Return the evaluation trace entries for the given sessions."""
        ...

    @webmethod(route="/telemetry/export-agent-trace", method="POST")
    async def export_agent_trace(
        self,
        session_ids: List[str],
        # Explicit Optional: a bare ``str = None`` is an implicit Optional,
        # which type checkers reject and which misstates the accepted values.
        dataset_id: Optional[str] = None,
    ) -> None:
        """Export the agent trace entries for the given sessions to ``dataset_id``."""
        ...

View file

@ -7,8 +7,6 @@
import json import json
from typing import List, Optional from typing import List, Optional
from llama_stack.apis.telemetry.telemetry import Trace
from .config import LogFormat from .config import LogFormat
from llama_stack.apis.telemetry import * # noqa: F403 from llama_stack.apis.telemetry import * # noqa: F403
@ -51,13 +49,25 @@ class ConsoleTelemetryImpl(Telemetry):
if formatted: if formatted:
print(formatted) print(formatted)
async def get_trace(self, trace_id: str) -> TraceTree:
    """Unsupported by console telemetry.

    Raises:
        NotImplementedError: always.
    """
    reason = "Console telemetry does not support trace retrieval"
    raise NotImplementedError(reason)
async def get_agent_trace(
    self,
    session_ids: List[str],
) -> List[EvalTrace]:
    """Unsupported by console telemetry.

    Raises:
        NotImplementedError: always.
    """
    reason = "Console telemetry does not support agent trace retrieval"
    raise NotImplementedError(reason)
async def export_agent_trace(
    self,
    session_ids: List[str],
    # Explicit Optional instead of the implicit ``str = None``.
    dataset_id: Optional[str] = None,
) -> None:
    """Unsupported by console telemetry.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "Console telemetry does not support agent trace export"
    )
COLORS = { COLORS = {

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
from .config import OpenTelemetryConfig from .config import OpenTelemetryConfig
async def get_adapter_impl(config: OpenTelemetryConfig, deps):
    """Create the OpenTelemetry adapter, wire in a Jaeger trace store, and initialize it.

    Args:
        config: Adapter configuration; supplies the Jaeger query endpoint and
            the service name used to build the trace store.
        deps: Dependencies passed through to the adapter unchanged.

    Returns:
        The initialized ``OpenTelemetryAdapter``.
    """
    # Function-scope import, as in the original.
    from .opentelemetry import OpenTelemetryAdapter

    adapter = OpenTelemetryAdapter(
        config,
        JaegerTraceStore(config.jaeger_query_endpoint, config.service_name),
        deps,
    )
    await adapter.initialize()
    return adapter

View file

@ -18,7 +18,7 @@ class OpenTelemetryConfig(BaseModel):
default="llama-stack", default="llama-stack",
description="The service name to use for telemetry", description="The service name to use for telemetry",
) )
export_endpoint: str = Field( jaeger_query_endpoint: str = Field(
default="http://localhost:16686/api/traces", default="http://localhost:16686/api/traces",
description="The Jaeger query endpoint URL", description="The Jaeger query endpoint URL",
) )

View file

@ -5,9 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import threading import threading
from typing import Any, Dict, List, Optional from typing import List
import aiohttp
from opentelemetry import metrics, trace from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -51,9 +49,12 @@ def is_tracing_enabled(tracer):
class OpenTelemetryAdapter(Telemetry): class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig, deps) -> None: def __init__(
self, config: OpenTelemetryConfig, trace_store: TraceStore, deps
) -> None:
self.config = config self.config = config
self.datasetio = deps[Api.datasetio] self.datasetio = deps[Api.datasetio]
self.trace_store = trace_store
resource = Resource.create( resource = Resource.create(
{ {
@ -202,157 +203,67 @@ class OpenTelemetryAdapter(Telemetry):
span.end() span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None) _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
async def get_trace(self, trace_id: str) -> TraceTree:
    """Look up a trace by ID by delegating to the configured trace store."""
    tree = await self.trace_store.get_trace(trace_id)
    return tree
async def get_agent_trace(
    self,
    session_ids: List[str],
) -> List[EvalTrace]:
    """Collect the simplified agent trace entries for every given session.

    The trace store is queried one session at a time so that each resulting
    entry can be labelled with the session it was found under.
    """
    collected = []
    for sid in session_ids:
        for found in await self.trace_store.get_traces_for_sessions([sid]):
            entries = await self._get_simplified_agent_trace(found.trace_id, sid)
            collected.extend(entries)
    return collected
async def export_agent_trace(
    self, session_ids: List[str], dataset_id: str = None
) -> None:
    """Fetch the agent trace entries for the sessions and upload them as dataset rows.

    NOTE(review): ``dataset_id`` defaults to None and is forwarded to
    ``datasetio.upload_rows`` as-is -- confirm whether a missing ID should be
    rejected up front.
    """
    rows = []
    for entry in await self.get_agent_trace(session_ids):
        rows.append(
            {
                "step": entry.step,
                "input": entry.input,
                "output": entry.output,
                "session_id": entry.session_id,
            }
        )
    await self.datasetio.upload_rows(dataset_id, rows)
async def _get_simplified_agent_trace(
    self, trace_id: str, session_id: str
) -> List[EvalTrace]:
    """Flatten a trace into one entry per direct child of each turn span.

    Every span named ``create_and_execute_turn`` contributes one ``EvalTrace``
    per direct child, ordered by the child's start time. Spans are visited
    parent-first, with siblings in their stored order.

    Returns:
        The collected entries, or an empty list when the trace or its root
        node is missing.
    """
    tree = await self.get_trace(trace_id)
    if not tree or not tree.root:
        return []

    steps = []
    # Explicit stack; children are pushed reversed so they pop in stored order.
    pending = [tree.root]
    while pending:
        node = pending.pop()
        if node.span.name == "create_and_execute_turn":
            ordered = sorted(node.children, key=lambda n: n.span.start_time)
            steps.extend(
                EvalTrace(
                    step=child.span.name,
                    input=child.span.attributes.get("input", ""),
                    output=child.span.attributes.get("output", ""),
                    session_id=session_id,
                    expected_output="",
                )
                for child in ordered
            )
        pending.extend(reversed(node.children))
    return steps

View file

@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timedelta
from typing import List
import aiohttp
from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree
class JaegerTraceStore(TraceStore):
    """Trace store that reads traces back from a Jaeger query HTTP endpoint.

    Args:
        endpoint: URL of the Jaeger traces API. Searches query it directly;
            single-trace lookups append ``/<trace_id>``.
        service_name: Value sent as the ``service`` search parameter.
        limit: Value sent as the ``limit`` search parameter (default 100).
        lookback: Value sent as the ``lookback`` search parameter
            (default ``"10000h"``).
    """

    def __init__(
        self,
        endpoint: str,
        service_name: str,
        *,
        limit: int = 100,
        lookback: str = "10000h",
    ):
        self.endpoint = endpoint
        self.service_name = service_name
        self.limit = limit
        self.lookback = lookback

    @staticmethod
    def _span_from_jaeger(trace_id: str, jaeger_span: dict) -> Span:
        """Convert one span object from a Jaeger response into a ``Span``."""
        # startTime / endTime / duration are treated as microseconds.
        # NOTE(review): fromtimestamp() yields naive local-time datetimes --
        # confirm whether timezone-aware values are expected downstream.
        start_time = datetime.fromtimestamp(jaeger_span["startTime"] / 1000000)
        if "endTime" in jaeger_span:
            # Some systems store the end time directly in the span.
            end_time = datetime.fromtimestamp(jaeger_span["endTime"] / 1000000)
        else:
            end_time = start_time + timedelta(
                microseconds=jaeger_span.get("duration", 0)
            )

        # The parent is the first CHILD_OF reference, if there is one.
        parent_span_id = next(
            (
                ref["spanID"]
                for ref in jaeger_span.get("references", [])
                if ref["refType"] == "CHILD_OF"
            ),
            None,
        )
        return Span(
            span_id=jaeger_span["spanID"],
            trace_id=trace_id,
            name=jaeger_span["operationName"],
            start_time=start_time,
            end_time=end_time,
            parent_span_id=parent_span_id,
            attributes={
                tag["key"]: tag["value"] for tag in jaeger_span.get("tags", [])
            },
        )

    async def get_trace(self, trace_id: str) -> TraceTree:
        """Fetch one trace from Jaeger and assemble its span tree.

        Returns:
            A ``TraceTree`` for the trace, or ``None`` when the response
            carries no trace data.

        Raises:
            Exception: on a non-200 response or any failure while fetching or
                parsing the trace (the original error is chained).
        """
        params = {
            "traceID": trace_id,
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.endpoint}/{trace_id}", params=params
                ) as response:
                    if response.status != 200:
                        raise Exception(
                            f"Failed to query Jaeger: {response.status} {await response.text()}"
                        )
                    trace_data = await response.json()

            if not trace_data.get("data"):
                return None

            # First pass: one node per span, keyed by span ID.
            span_map = {}
            for jaeger_span in trace_data["data"][0]["spans"]:
                span = self._span_from_jaeger(trace_id, jaeger_span)
                span_map[span.span_id] = SpanNode(span=span)

            # Second pass: attach each node to its parent. A span without a
            # parent becomes the root (the last such span wins); a span whose
            # parent is not part of the response is left unattached.
            root_node = None
            for span_node in span_map.values():
                parent_id = span_node.span.parent_span_id
                if parent_id and parent_id in span_map:
                    span_map[parent_id].children.append(span_node)
                elif not parent_id:
                    root_node = span_node

            trace = Trace(
                trace_id=trace_id,
                root_span_id=root_node.span.span_id if root_node else "",
                start_time=(
                    root_node.span.start_time if root_node else datetime.now()
                ),
                end_time=root_node.span.end_time if root_node else None,
            )
            return TraceTree(trace=trace, root=root_node)

        except Exception as e:
            raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e

    async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]:
        """Search Jaeger for traces tagged with any of the given session IDs.

        Each session ID is searched separately; a trace ID is returned at most
        once even if it matches several sessions. The returned ``Trace``
        objects carry only the trace ID: ``root_span_id`` is empty and
        ``start_time`` is the time of the lookup.

        Raises:
            Exception: on a non-200 response or any failure while querying
                (the original error is chained).
        """
        # Local import so this block does not depend on the module's import list.
        import json

        if not session_ids:
            return []

        traces = []
        seen_trace_ids = set()

        try:
            # One HTTP session is reused for every search.
            async with aiohttp.ClientSession() as session:
                for session_id in session_ids:
                    params = {
                        "service": self.service_name,
                        # json.dumps escapes quotes/backslashes in the ID;
                        # compact separators keep the previous wire format.
                        "tags": json.dumps(
                            {"session_id": session_id}, separators=(",", ":")
                        ),
                        "limit": self.limit,
                        "lookback": self.lookback,
                    }
                    async with session.get(self.endpoint, params=params) as response:
                        if response.status != 200:
                            raise Exception(
                                f"Failed to query Jaeger: {response.status} {await response.text()}"
                            )
                        traces_data = await response.json()

                    for trace_data in traces_data.get("data", []):
                        trace_id = trace_data.get("traceID")
                        if trace_id and trace_id not in seen_trace_ids:
                            seen_trace_ids.add(trace_id)
                            traces.append(
                                Trace(
                                    trace_id=trace_id,
                                    root_span_id="",
                                    start_time=datetime.now(),
                                )
                            )
        except Exception as e:
            raise Exception(f"Error querying Jaeger traces: {str(e)}") from e

        return traces