llama-stack/llama_stack/providers/utils/telemetry/sqlite_trace_store.py
Dinesh Yeduguru c23363d561
Add ability to query and export spans to dataset (#574)
This PR adds two new methods to the telemetry API:
1) One that queries spans directly, instead of first querying
traces and then using those to get spans.
2) save_spans_to_dataset, which builds on the span query above
to save the matching spans to a dataset.

This gives the ability to save spans that are part of an agent session
to a dataset.

The unique aspect of this API is that we don't require each provider of
telemetry to implement this method. Hence, it's implemented in the
protocol class itself. This required the protocol check to be slightly
modified.
2024-12-05 21:07:30 -08:00

180 lines
6.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import List, Optional, Protocol
import aiosqlite
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
class TraceStore(Protocol):
    """Structural (duck-typed) interface for backends that can query traces and span trees."""

    async def query_traces(
        self,
        attribute_filters: Optional[List[QueryCondition]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]:
        """Return traces selected by ``attribute_filters``.

        Results are paged with ``limit``/``offset`` and sorted according to
        ``order_by``; the exact semantics are defined by each implementation.
        """
        ...

    async def get_span_tree(
        self,
        span_id: str,
        attributes_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
    ) -> SpanWithChildren:
        """Return the span ``span_id`` together with its descendants.

        ``attributes_to_return`` may narrow the attributes reported per span and
        ``max_depth`` may bound how deep the tree goes.
        """
        ...
class SQLiteTraceStore(TraceStore):
    """Trace store that reads the ``traces`` and ``spans`` tables of a SQLite database.

    Span ``attributes`` are read with SQLite's JSON functions (``json_extract`` /
    ``json_object``), so that column is expected to hold JSON text.
    """

    # Operator names accepted in a QueryCondition, mapped to their SQL spelling.
    _SQL_OPERATORS = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}

    def __init__(self, conn_string: str):
        # Passed verbatim to ``aiosqlite.connect``; every query opens its own connection.
        self.conn_string = conn_string

    @staticmethod
    def _parse_timestamp(value: Optional[str]) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp column, passing SQL NULL through as ``None``."""
        return datetime.fromisoformat(value) if value is not None else None

    @classmethod
    def _build_trace_filter(
        cls, attribute_filters: Optional[List[QueryCondition]]
    ) -> tuple[str, list]:
        """Build the WHERE clause (over span alias ``s``) and its bound parameters."""
        if not attribute_filters:
            return "", []
        conditions = []
        params: list = []
        for condition in attribute_filters:
            # ``op`` may be an Enum carrying the operator name in ``.value`` or a
            # plain string; accept both. NOTE(review): confirm against QueryCondition.
            op_name = getattr(condition.op, "value", condition.op)
            sql_op = cls._SQL_OPERATORS.get(op_name)
            if sql_op is None:
                raise ValueError(f"Unsupported query condition operator: {condition.op!r}")
            # The JSON path is bound rather than spliced into the SQL text, so a key
            # containing quotes cannot change the statement.
            conditions.append(f"json_extract(s.attributes, ?) {sql_op} ?")
            params.extend([f"$.{condition.key}", condition.value])
        return " WHERE " + " AND ".join(conditions), params

    @staticmethod
    def _build_trace_order(order_by: Optional[List[str]]) -> str:
        """Turn ``order_by`` entries ("col" / "-col") into an ORDER BY on alias ``t``."""
        if not order_by:
            return ""
        terms = []
        for field in order_by:
            descending = field.startswith("-")
            column = field[1:] if descending else field
            # Column names cannot be bound as parameters, so only bare identifiers
            # are allowed to reach the SQL text.
            if not column.isidentifier():
                raise ValueError(f"Invalid order_by field: {field!r}")
            terms.append(f"t.{column} {'DESC' if descending else 'ASC'}")
        return " ORDER BY " + ", ".join(terms)

    async def query_traces(
        self,
        attribute_filters: Optional[List[QueryCondition]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]:
        """Return traces that contain at least one span satisfying all filters.

        Args:
            attribute_filters: Conditions ANDed together and evaluated against a
                single span's ``json_extract(attributes, '$.<key>')``.
            limit: Maximum number of traces to return; ``None`` means no limit.
            offset: Number of traces to skip; ``None`` is treated as 0.
            order_by: ``traces`` column names; a leading ``-`` sorts descending.

        Raises:
            ValueError: If a filter uses an unsupported operator or an
                ``order_by`` entry is not a bare column name.
        """
        where_sql, where_params = self._build_trace_filter(attribute_filters)
        order_sql = self._build_trace_order(order_by)

        # ORDER BY must be on the outermost SELECT: an ORDER BY inside a CTE does
        # not guarantee the order of the final DISTINCT/LIMIT result.
        query = f"""
            WITH matching_traces AS (
                SELECT DISTINCT t.trace_id
                FROM traces t
                JOIN spans s ON t.trace_id = s.trace_id
                {where_sql}
            )
            SELECT DISTINCT t.trace_id, t.root_span_id, t.start_time, t.end_time
            FROM matching_traces mt
            JOIN traces t ON mt.trace_id = t.trace_id
            {order_sql}
            LIMIT ? OFFSET ?
        """
        # SQLite treats a negative LIMIT as "no limit".
        params = [*where_params, -1 if limit is None else limit, offset or 0]

        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()

        return [
            Trace(
                trace_id=row["trace_id"],
                root_span_id=row["root_span_id"],
                start_time=self._parse_timestamp(row["start_time"]),
                end_time=self._parse_timestamp(row["end_time"]),
            )
            for row in rows
        ]

    async def get_span_tree(
        self,
        span_id: str,
        attributes_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
    ) -> SpanWithChildren:
        """Load span ``span_id`` and its descendants as a nested tree.

        Args:
            span_id: Id of the span that becomes the root of the returned tree.
            attributes_to_return: If given, each span's attributes are reduced to
                these keys (a key absent from a span comes back as ``None``);
                otherwise the full stored attributes are returned.
            max_depth: Maximum depth, counting the root as depth 1; ``None``
                means unlimited.

        Raises:
            ValueError: If no span with ``span_id`` exists.
        """
        attr_params: list = []
        if attributes_to_return:
            pairs = []
            for key in attributes_to_return:
                # Label and JSON path are bound, never spliced into the SQL text.
                pairs.append("?, json_extract(st.attributes, ?)")
                attr_params.extend([key, f"$.{key}"])
            attributes_select = f"json_object({', '.join(pairs)})"
        else:
            attributes_select = "st.attributes"

        # The attribute projection is applied once, in the outer SELECT, so its
        # placeholders follow the CTE's placeholders in textual order.
        query = f"""
            WITH RECURSIVE span_tree AS (
                SELECT s.*, 1 AS depth
                FROM spans s
                WHERE s.span_id = ?
                UNION ALL
                SELECT s.*, st.depth + 1
                FROM spans s
                JOIN span_tree st ON s.parent_span_id = st.span_id
                WHERE (? IS NULL OR st.depth < ?)
            )
            SELECT st.*, {attributes_select} AS filtered_attributes
            FROM span_tree st
            ORDER BY st.depth, st.start_time
        """
        params = [span_id, max_depth, max_depth, *attr_params]

        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()

        if not rows:
            raise ValueError(f"Span {span_id} not found")

        spans_by_id: dict = {}
        root_span = None
        # Rows are ordered by depth, so a parent is always registered before its children.
        for row in rows:
            raw_attributes = row["filtered_attributes"]
            span = SpanWithChildren(
                span_id=row["span_id"],
                trace_id=row["trace_id"],
                parent_span_id=row["parent_span_id"],
                name=row["name"],
                start_time=self._parse_timestamp(row["start_time"]),
                end_time=self._parse_timestamp(row["end_time"]),
                # NULL attributes (no stored JSON) map to None instead of crashing json.loads.
                attributes=json.loads(raw_attributes) if raw_attributes is not None else None,
                status=row["status"].lower(),
                children=[],
            )
            spans_by_id[span.span_id] = span
            if span.span_id == span_id:
                root_span = span
            elif span.parent_span_id in spans_by_id:
                spans_by_id[span.parent_span_id].children.append(span)

        return root_span