forked from phoenix-oss/llama-stack-mirror
This PR adds two new methods to the telemetry API: (1) a method that queries spans directly, instead of first querying traces and then using them to fetch spans; and (2) `save_spans_to_dataset`, which builds on span querying to save the matching spans to a dataset. This makes it possible to save the spans that are part of an agent session to a dataset. The unique aspect of this API is that we don't require each telemetry provider to implement this method; hence, it is implemented in the protocol class itself. This required the protocol check to be slightly modified.
180 lines
6.2 KiB
Python
180 lines
6.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
from datetime import datetime
|
|
from typing import List, Optional, Protocol
|
|
|
|
import aiosqlite
|
|
|
|
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
|
|
|
|
|
|
class TraceStore(Protocol):
    """Storage interface for querying persisted traces and spans.

    Declared as a ``typing.Protocol``: each method below only fixes a
    signature and carries no implementation of its own.
    """

    async def query_traces(
        self,
        attribute_filters: Optional[List[QueryCondition]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]:
        """Return traces matching ``attribute_filters``, paginated by ``limit``/``offset``.

        ``order_by`` names the fields to sort by; its exact syntax is left to
        each implementation.
        """

    async def get_span_tree(
        self,
        span_id: str,
        attributes_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
    ) -> SpanWithChildren:
        """Return the span ``span_id`` with its descendants nested as children.

        ``attributes_to_return`` optionally restricts which span attributes are
        included, and ``max_depth`` optionally bounds how deep the tree goes.
        """
|
|
|
|
|
|
class SQLiteTraceStore(TraceStore):
    """TraceStore implementation that reads ``traces`` and ``spans`` tables from SQLite.

    Every public method opens its own ``aiosqlite`` connection, so an instance
    holds no open connection between calls.
    """

    # SQL operators for the supported QueryCondition operators.
    _OPS_MAP = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}

    def __init__(self, conn_string: str):
        # Passed unchanged to aiosqlite.connect() on every call.
        self.conn_string = conn_string

    @staticmethod
    def _json_path(key: str) -> str:
        """Build the SQLite JSON path for attribute ``key``.

        The result is always bound as a query parameter, never interpolated into
        SQL, so quotes in ``key`` cannot break out of the statement. Dots in
        ``key`` still address nested members, as they did before.
        """
        return f"$.{key}"

    @classmethod
    def _sql_operator(cls, op) -> str:
        """Map a QueryCondition operator to its SQL form; raises KeyError if unsupported."""
        # Accept both an Enum member (use its ``.value``) and a plain string.
        # NOTE(review): confirm the declared type of QueryCondition.op.
        return cls._OPS_MAP[getattr(op, "value", op)]

    @staticmethod
    def _parse_timestamp(value: Optional[str]) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp column, mapping SQL NULL to ``None``."""
        # presumably end_time can be NULL for traces/spans that have not finished — confirm
        return None if value is None else datetime.fromisoformat(value)

    @staticmethod
    def _build_order_clause(order_by: Optional[List[str]]) -> str:
        """Translate e.g. ``["-start_time", "trace_id"]`` into an ORDER BY clause on ``traces``.

        A leading ``-`` sorts that column descending. Column names cannot be bound
        as parameters, so each must be a plain identifier to prevent SQL injection.

        Raises:
            ValueError: if an entry is not a plain column identifier.
        """
        if not order_by:
            return ""

        terms = []
        for field in order_by:
            descending = field.startswith("-")
            column = field[1:] if descending else field
            if not column.isidentifier():
                raise ValueError(f"Invalid order_by field: {field!r}")
            terms.append(f"t.{column} {'DESC' if descending else 'ASC'}")
        return " ORDER BY " + ", ".join(terms)

    async def query_traces(
        self,
        attribute_filters: Optional[List[QueryCondition]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]:
        """Return traces that have at least one span matching every attribute filter.

        Args:
            attribute_filters: Conditions ANDed together and evaluated against the
                same span; each compares ``json_extract(span.attributes, '$.<key>')``
                with ``value``. With no filters, every trace that has at least one
                span is eligible.
            limit: Maximum number of traces to return; ``None`` means no limit.
            offset: Number of matching traces to skip; ``None`` is treated as 0.
            order_by: Columns of the ``traces`` table to sort by; a leading ``-``
                sorts that column descending.

        Returns:
            The requested page of traces.

        Raises:
            ValueError: if an ``order_by`` entry is not a plain column identifier.
            KeyError: if a filter uses an operator other than eq/ne/gt/lt.
        """
        params: list = []

        where_clause = ""
        if attribute_filters:
            conditions = []
            for condition in attribute_filters:
                conditions.append(
                    f"json_extract(s.attributes, ?) {self._sql_operator(condition.op)} ?"
                )
                # Parameters must follow the textual order of the placeholders.
                params.extend([self._json_path(condition.key), condition.value])
            where_clause = " WHERE " + " AND ".join(conditions)

        order_clause = self._build_order_clause(order_by)

        # ORDER BY lives on the outermost SELECT: ordering inside a CTE is not
        # guaranteed to survive into the final result, which would make
        # LIMIT/OFFSET pagination unstable.
        query = f"""
            WITH matching_traces AS (
                SELECT DISTINCT t.trace_id
                FROM traces t
                JOIN spans s ON t.trace_id = s.trace_id
                {where_clause}
            )
            SELECT DISTINCT t.trace_id, t.root_span_id, t.start_time, t.end_time
            FROM matching_traces mt
            JOIN traces t ON mt.trace_id = t.trace_id
            {order_clause}
            LIMIT ? OFFSET ?
        """
        # SQLite treats a negative LIMIT as "no upper bound".
        params.append(-1 if limit is None else limit)
        params.append(0 if offset is None else offset)

        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()
                return [
                    Trace(
                        trace_id=row["trace_id"],
                        root_span_id=row["root_span_id"],
                        start_time=self._parse_timestamp(row["start_time"]),
                        end_time=self._parse_timestamp(row["end_time"]),
                    )
                    for row in rows
                ]

    async def get_span_tree(
        self,
        span_id: str,
        attributes_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
    ) -> SpanWithChildren:
        """Load the subtree of spans rooted at ``span_id``.

        Args:
            span_id: ID of the span that becomes the root of the returned tree.
            attributes_to_return: If non-empty, each span's ``attributes`` holds
                only these keys (a key absent from a span maps to null). Otherwise
                the span's full attribute object is returned.
            max_depth: Maximum number of levels to return, counting the root as
                depth 1 (the root is always returned); ``None`` means unlimited.

        Returns:
            The root span with its descendants attached via ``children``.

        Raises:
            ValueError: if no span with ``span_id`` exists.
        """
        attributes_select = "s.attributes"
        attr_params: list = []
        if attributes_to_return:
            # Labels and JSON paths are bound as parameters rather than
            # interpolated, so attribute names cannot inject SQL.
            attributes_select = "json_object({})".format(
                ", ".join("?, json_extract(s.attributes, ?)" for _ in attributes_to_return)
            )
            for key in attributes_to_return:
                attr_params.extend([key, self._json_path(key)])

        # Recursive CTE: the anchor selects the root span, and each recursive step
        # adds the children of spans already in the tree until max_depth is reached.
        query = f"""
            WITH RECURSIVE span_tree AS (
                SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
                FROM spans s
                WHERE s.span_id = ?

                UNION ALL

                SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
                FROM spans s
                JOIN span_tree st ON s.parent_span_id = st.span_id
                WHERE (? IS NULL OR st.depth < ?)
            )
            SELECT *
            FROM span_tree
            ORDER BY depth, start_time
        """
        # Positional parameters in textual order: anchor attribute labels/paths,
        # root span id, recursive attribute labels/paths, then max_depth twice.
        params = (*attr_params, span_id, *attr_params, max_depth, max_depth)

        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()

                if not rows:
                    raise ValueError(f"Span {span_id} not found")

                spans_by_id = {}
                root_span = None

                for row in rows:
                    span = SpanWithChildren(
                        span_id=row["span_id"],
                        trace_id=row["trace_id"],
                        parent_span_id=row["parent_span_id"],
                        name=row["name"],
                        start_time=self._parse_timestamp(row["start_time"]),
                        end_time=self._parse_timestamp(row["end_time"]),
                        attributes=json.loads(row["filtered_attributes"]),
                        status=row["status"].lower(),
                        children=[],
                    )

                    spans_by_id[span.span_id] = span

                    if span.span_id == span_id:
                        root_span = span
                    elif span.parent_span_id in spans_by_id:
                        # Rows arrive ordered by depth, so a span's parent has
                        # already been created by the time the span is seen.
                        spans_by_id[span.parent_span_id].children.append(span)

                return root_span
|