mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
new types and working
This commit is contained in:
parent
b7c4997e91
commit
94b4113c63
6 changed files with 252 additions and 157 deletions
|
@ -143,21 +143,52 @@ class EvalTrace(BaseModel):
|
||||||
step: str
|
step: str
|
||||||
input: str
|
input: str
|
||||||
output: str
|
output: str
|
||||||
|
expected_output: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class SpanNode(BaseModel):
|
||||||
|
span: Span
|
||||||
|
children: List["SpanNode"] = Field(default_factory=list)
|
||||||
|
status: Optional[SpanStatus] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TraceTree(BaseModel):
|
||||||
|
trace: Trace
|
||||||
|
root: Optional[SpanNode] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TraceStore(Protocol):
|
||||||
|
async def get_trace(
|
||||||
|
self,
|
||||||
|
trace_id: str,
|
||||||
|
) -> TraceTree: ...
|
||||||
|
|
||||||
|
async def get_traces_for_sessions(
|
||||||
|
self,
|
||||||
|
session_ids: List[str],
|
||||||
|
) -> [Trace]: ...
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class Telemetry(Protocol):
|
class Telemetry(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/telemetry/log-event")
|
@webmethod(route="/telemetry/log-event")
|
||||||
async def log_event(self, event: Event) -> None: ...
|
async def log_event(self, event: Event) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/get-trace", method="POST")
|
@webmethod(route="/telemetry/get-trace", method="POST")
|
||||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
async def get_trace(self, trace_id: str) -> TraceTree: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/get-traces-for-agent-eval", method="POST")
|
@webmethod(route="/telemetry/get-agent-trace", method="POST")
|
||||||
async def get_traces_for_agent_eval(
|
async def get_agent_trace(
|
||||||
self,
|
self,
|
||||||
session_ids: List[str],
|
session_ids: List[str],
|
||||||
lookback: str = "1h",
|
|
||||||
limit: int = 100,
|
|
||||||
dataset_id: Optional[str] = None,
|
|
||||||
) -> List[EvalTrace]: ...
|
) -> List[EvalTrace]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/telemetry/export-agent-trace", method="POST")
|
||||||
|
async def export_agent_trace(
|
||||||
|
self,
|
||||||
|
session_ids: List[str],
|
||||||
|
dataset_id: str = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
|
@ -7,8 +7,6 @@
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.telemetry.telemetry import Trace
|
|
||||||
|
|
||||||
from .config import LogFormat
|
from .config import LogFormat
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
@ -51,13 +49,25 @@ class ConsoleTelemetryImpl(Telemetry):
|
||||||
if formatted:
|
if formatted:
|
||||||
print(formatted)
|
print(formatted)
|
||||||
|
|
||||||
async def get_trace(self, trace_id: str) -> Trace:
|
async def get_trace(self, trace_id: str) -> TraceTree:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("Console telemetry does not support trace retrieval")
|
||||||
|
|
||||||
async def get_traces_for_agent_eval(
|
async def get_agent_trace(
|
||||||
self, session_ids: List[str], lookback: str = "1h", limit: int = 100
|
self,
|
||||||
|
session_ids: List[str],
|
||||||
) -> List[EvalTrace]:
|
) -> List[EvalTrace]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError(
|
||||||
|
"Console telemetry does not support agent trace retrieval"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def export_agent_trace(
|
||||||
|
self,
|
||||||
|
session_ids: List[str],
|
||||||
|
dataset_id: str = None,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Console telemetry does not support agent trace export"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
COLORS = {
|
COLORS = {
|
||||||
|
|
|
@ -4,12 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
|
||||||
from .config import OpenTelemetryConfig
|
from .config import OpenTelemetryConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: OpenTelemetryConfig, deps):
|
async def get_adapter_impl(config: OpenTelemetryConfig, deps):
|
||||||
from .opentelemetry import OpenTelemetryAdapter
|
from .opentelemetry import OpenTelemetryAdapter
|
||||||
|
|
||||||
impl = OpenTelemetryAdapter(config, deps)
|
trace_store = JaegerTraceStore(config.jaeger_query_endpoint, config.service_name)
|
||||||
|
impl = OpenTelemetryAdapter(config, trace_store, deps)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -18,7 +18,7 @@ class OpenTelemetryConfig(BaseModel):
|
||||||
default="llama-stack",
|
default="llama-stack",
|
||||||
description="The service name to use for telemetry",
|
description="The service name to use for telemetry",
|
||||||
)
|
)
|
||||||
export_endpoint: str = Field(
|
jaeger_query_endpoint: str = Field(
|
||||||
default="http://localhost:16686/api/traces",
|
default="http://localhost:16686/api/traces",
|
||||||
description="The Jaeger query endpoint URL",
|
description="The Jaeger query endpoint URL",
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,9 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import List
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
from opentelemetry import metrics, trace
|
from opentelemetry import metrics, trace
|
||||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||||
|
@ -51,9 +49,12 @@ def is_tracing_enabled(tracer):
|
||||||
|
|
||||||
|
|
||||||
class OpenTelemetryAdapter(Telemetry):
|
class OpenTelemetryAdapter(Telemetry):
|
||||||
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
|
def __init__(
|
||||||
|
self, config: OpenTelemetryConfig, trace_store: TraceStore, deps
|
||||||
|
) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasetio = deps[Api.datasetio]
|
self.datasetio = deps[Api.datasetio]
|
||||||
|
self.trace_store = trace_store
|
||||||
|
|
||||||
resource = Resource.create(
|
resource = Resource.create(
|
||||||
{
|
{
|
||||||
|
@ -202,157 +203,67 @@ class OpenTelemetryAdapter(Telemetry):
|
||||||
span.end()
|
span.end()
|
||||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||||
|
|
||||||
async def get_traces_for_agent_eval(
|
async def get_trace(self, trace_id: str) -> TraceTree:
|
||||||
|
return await self.trace_store.get_trace(trace_id)
|
||||||
|
|
||||||
|
async def get_agent_trace(
|
||||||
self,
|
self,
|
||||||
session_ids: List[str],
|
session_ids: List[str],
|
||||||
lookback: str = "1h",
|
|
||||||
limit: int = 100,
|
|
||||||
dataset_id: Optional[str] = None,
|
|
||||||
) -> List[EvalTrace]:
|
) -> List[EvalTrace]:
|
||||||
traces = []
|
traces = []
|
||||||
|
|
||||||
# Fetch traces for each session ID individually
|
|
||||||
for session_id in session_ids:
|
for session_id in session_ids:
|
||||||
params = {
|
traces_for_session = await self.trace_store.get_traces_for_sessions(
|
||||||
"service": self.config.service_name,
|
[session_id]
|
||||||
"lookback": lookback,
|
)
|
||||||
"limit": limit,
|
for session_trace in traces_for_session:
|
||||||
"tags": f'{{"session_id":"{session_id}"}}',
|
trace_details = await self._get_simplified_agent_trace(
|
||||||
}
|
session_trace.trace_id, session_id
|
||||||
|
)
|
||||||
|
traces.extend(trace_details)
|
||||||
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(
|
|
||||||
self.config.export_endpoint, params=params
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to query Jaeger: {response.status} {await response.text()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
traces_data = await response.json()
|
|
||||||
seen_trace_ids = set()
|
|
||||||
|
|
||||||
for trace_data in traces_data.get("data", []):
|
|
||||||
trace_id = trace_data.get("traceID")
|
|
||||||
if trace_id and trace_id not in seen_trace_ids:
|
|
||||||
seen_trace_ids.add(trace_id)
|
|
||||||
trace_details = await self.get_trace_for_eval(
|
|
||||||
trace_id, session_id
|
|
||||||
)
|
|
||||||
traces.extend(trace_details)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
|
|
||||||
|
|
||||||
if dataset_id:
|
|
||||||
traces_dict = [
|
|
||||||
{
|
|
||||||
"step": trace.step,
|
|
||||||
"input": trace.input,
|
|
||||||
"output": trace.output,
|
|
||||||
"session_id": trace.session_id,
|
|
||||||
}
|
|
||||||
for trace in traces
|
|
||||||
]
|
|
||||||
await self.datasetio.upload_rows(dataset_id, traces_dict)
|
|
||||||
return traces
|
return traces
|
||||||
|
|
||||||
async def get_trace(self, trace_id: str) -> Dict[str, Any]:
|
async def export_agent_trace(
|
||||||
params = {
|
self, session_ids: List[str], dataset_id: str = None
|
||||||
"traceID": trace_id,
|
) -> None:
|
||||||
}
|
traces = await self.get_agent_trace(session_ids)
|
||||||
|
traces_dict = [
|
||||||
|
{
|
||||||
|
"step": trace.step,
|
||||||
|
"input": trace.input,
|
||||||
|
"output": trace.output,
|
||||||
|
"session_id": trace.session_id,
|
||||||
|
}
|
||||||
|
for trace in traces
|
||||||
|
]
|
||||||
|
await self.datasetio.upload_rows(dataset_id, traces_dict)
|
||||||
|
|
||||||
try:
|
async def _get_simplified_agent_trace(
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(
|
|
||||||
f"{self.config.export_endpoint}/{trace_id}", params=params
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to query Jaeger: {response.status} {await response.text()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
trace_data = await response.json()
|
|
||||||
if not trace_data.get("data") or not trace_data["data"]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# First pass: Build span map
|
|
||||||
span_map = {}
|
|
||||||
for span in trace_data["data"][0]["spans"]:
|
|
||||||
start_time = span["startTime"]
|
|
||||||
end_time = start_time + span.get(
|
|
||||||
"duration", 0
|
|
||||||
) # Get end time from duration if available
|
|
||||||
|
|
||||||
# Some systems store end time directly in the span
|
|
||||||
if "endTime" in span:
|
|
||||||
end_time = span["endTime"]
|
|
||||||
duration = end_time - start_time
|
|
||||||
else:
|
|
||||||
duration = span.get("duration", 0)
|
|
||||||
|
|
||||||
span_map[span["spanID"]] = {
|
|
||||||
"id": span["spanID"],
|
|
||||||
"name": span["operationName"],
|
|
||||||
"start_time": start_time,
|
|
||||||
"end_time": end_time,
|
|
||||||
"duration": duration,
|
|
||||||
"tags": {
|
|
||||||
tag["key"]: tag["value"] for tag in span.get("tags", [])
|
|
||||||
},
|
|
||||||
"children": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Second pass: Build parent-child relationships
|
|
||||||
root_spans = []
|
|
||||||
for span in trace_data["data"][0]["spans"]:
|
|
||||||
references = span.get("references", [])
|
|
||||||
if references and references[0]["refType"] == "CHILD_OF":
|
|
||||||
parent_id = references[0]["spanID"]
|
|
||||||
if parent_id in span_map:
|
|
||||||
span_map[parent_id]["children"].append(
|
|
||||||
span_map[span["spanID"]]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
root_spans.append(span_map[span["spanID"]])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"trace_id": trace_id,
|
|
||||||
"spans": root_spans,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
|
|
||||||
|
|
||||||
async def get_trace_for_eval(
|
|
||||||
self, trace_id: str, session_id: str
|
self, trace_id: str, session_id: str
|
||||||
) -> List[EvalTrace]:
|
) -> List[EvalTrace]:
|
||||||
"""
|
trace_tree = await self.get_trace(trace_id)
|
||||||
Get simplified trace information focusing on first-level children of create_and_execute_turn operations.
|
if not trace_tree or not trace_tree.root:
|
||||||
Returns a list of spans with name, input, and output information, sorted by start time.
|
|
||||||
"""
|
|
||||||
trace_data = await self.get_trace(trace_id)
|
|
||||||
if not trace_data:
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def find_execute_turn_children(spans: List[Dict[str, Any]]) -> List[EvalTrace]:
|
def find_execute_turn_children(node: SpanNode) -> List[EvalTrace]:
|
||||||
results: List[EvalTrace] = []
|
results = []
|
||||||
for span in spans:
|
if node.span.name == "create_and_execute_turn":
|
||||||
if span["name"] == "create_and_execute_turn":
|
# Sort children by start time
|
||||||
# Extract and format children spans
|
sorted_children = sorted(node.children, key=lambda x: x.span.start_time)
|
||||||
children = sorted(span["children"], key=lambda x: x["start_time"])
|
for child in sorted_children:
|
||||||
for child in children:
|
results.append(
|
||||||
results.append(
|
EvalTrace(
|
||||||
EvalTrace(
|
step=child.span.name,
|
||||||
step=child["name"],
|
input=child.span.attributes.get("input", ""),
|
||||||
input=child["tags"].get("input", ""),
|
output=child.span.attributes.get("output", ""),
|
||||||
output=child["tags"].get("output", ""),
|
session_id=session_id,
|
||||||
session_id=session_id,
|
expected_output="",
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# Recursively search in children
|
)
|
||||||
results.extend(find_execute_turn_children(span["children"]))
|
|
||||||
|
# Recursively process children
|
||||||
|
for child in node.children:
|
||||||
|
results.extend(find_execute_turn_children(child))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
return find_execute_turn_children(trace_data["spans"])
|
return find_execute_turn_children(trace_tree.root)
|
||||||
|
|
141
llama_stack/providers/utils/telemetry/jaeger.py
Normal file
141
llama_stack/providers/utils/telemetry/jaeger.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree
|
||||||
|
|
||||||
|
|
||||||
|
class JaegerTraceStore(TraceStore):
|
||||||
|
def __init__(self, endpoint: str, service_name: str):
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.service_name = service_name
|
||||||
|
|
||||||
|
async def get_trace(self, trace_id: str) -> TraceTree:
|
||||||
|
params = {
|
||||||
|
"traceID": trace_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.endpoint}/{trace_id}", params=params
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to query Jaeger: {response.status} {await response.text()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
trace_data = await response.json()
|
||||||
|
if not trace_data.get("data") or not trace_data["data"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# First pass: Build span map
|
||||||
|
span_map = {}
|
||||||
|
for jaeger_span in trace_data["data"][0]["spans"]:
|
||||||
|
start_time = datetime.fromtimestamp(
|
||||||
|
jaeger_span["startTime"] / 1000000
|
||||||
|
)
|
||||||
|
|
||||||
|
# Some systems store end time directly in the span
|
||||||
|
if "endTime" in jaeger_span:
|
||||||
|
end_time = datetime.fromtimestamp(
|
||||||
|
jaeger_span["endTime"] / 1000000
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
duration_microseconds = jaeger_span.get("duration", 0)
|
||||||
|
duration_timedelta = timedelta(
|
||||||
|
microseconds=duration_microseconds
|
||||||
|
)
|
||||||
|
end_time = start_time + duration_timedelta
|
||||||
|
|
||||||
|
span = Span(
|
||||||
|
span_id=jaeger_span["spanID"],
|
||||||
|
trace_id=trace_id,
|
||||||
|
name=jaeger_span["operationName"],
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
parent_span_id=next(
|
||||||
|
(
|
||||||
|
ref["spanID"]
|
||||||
|
for ref in jaeger_span.get("references", [])
|
||||||
|
if ref["refType"] == "CHILD_OF"
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
attributes={
|
||||||
|
tag["key"]: tag["value"]
|
||||||
|
for tag in jaeger_span.get("tags", [])
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
span_map[span.span_id] = SpanNode(span=span)
|
||||||
|
|
||||||
|
# Second pass: Build parent-child relationships
|
||||||
|
root_node = None
|
||||||
|
for span_node in span_map.values():
|
||||||
|
parent_id = span_node.span.parent_span_id
|
||||||
|
if parent_id and parent_id in span_map:
|
||||||
|
span_map[parent_id].children.append(span_node)
|
||||||
|
elif not parent_id:
|
||||||
|
root_node = span_node
|
||||||
|
|
||||||
|
trace = Trace(
|
||||||
|
trace_id=trace_id,
|
||||||
|
root_span_id=root_node.span.span_id if root_node else "",
|
||||||
|
start_time=(
|
||||||
|
root_node.span.start_time if root_node else datetime.now()
|
||||||
|
),
|
||||||
|
end_time=root_node.span.end_time if root_node else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TraceTree(trace=trace, root=root_node)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
|
||||||
|
|
||||||
|
async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]:
|
||||||
|
traces = []
|
||||||
|
|
||||||
|
# Fetch traces for each session ID individually
|
||||||
|
for session_id in session_ids:
|
||||||
|
params = {
|
||||||
|
"service": self.service_name,
|
||||||
|
"tags": f'{{"session_id":"{session_id}"}}',
|
||||||
|
"limit": 100,
|
||||||
|
"lookback": "10000h",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(self.endpoint, params=params) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to query Jaeger: {response.status} {await response.text()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
traces_data = await response.json()
|
||||||
|
seen_trace_ids = set()
|
||||||
|
|
||||||
|
for trace_data in traces_data.get("data", []):
|
||||||
|
trace_id = trace_data.get("traceID")
|
||||||
|
if trace_id and trace_id not in seen_trace_ids:
|
||||||
|
seen_trace_ids.add(trace_id)
|
||||||
|
traces.append(
|
||||||
|
Trace(
|
||||||
|
trace_id=trace_id,
|
||||||
|
root_span_id="",
|
||||||
|
start_time=datetime.now(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
|
||||||
|
|
||||||
|
return traces
|
Loading…
Add table
Add a link
Reference in a new issue