From 5d0c502cdbc84404f146f9329821e0b39262d2f2 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 3 Dec 2024 16:25:20 -0800 Subject: [PATCH] consolidate telemetry to meta reference inline --- llama_stack/apis/telemetry/telemetry.py | 61 +++-- llama_stack/distribution/server/server.py | 8 +- .../meta_reference/telemetry/__init__.py | 15 -- .../inline/meta_reference/telemetry/config.py | 21 -- .../meta_reference/telemetry/console.py | 133 ---------- .../{remote => inline}/telemetry/__init__.py | 0 .../telemetry/meta_reference/__init__.py | 18 ++ .../inline/telemetry/meta_reference/config.py | 45 ++++ .../meta_reference/console_span_processor.py | 95 +++++++ .../meta_reference/sqlite_span_processor.py | 242 ++++++++++++++++++ .../telemetry/meta_reference/telemetry.py} | 170 +++++------- .../telemetry/sample/__init__.py | 0 .../telemetry/sample/config.py | 0 .../telemetry/sample/sample.py | 0 llama_stack/providers/registry/telemetry.py | 23 +- .../telemetry/opentelemetry/__init__.py | 15 -- .../remote/telemetry/opentelemetry/config.py | 39 --- .../opentelemetry/postgres_processor.py | 92 ------- .../providers/utils/telemetry/jaeger.py | 141 ---------- .../providers/utils/telemetry/postgres.py | 114 --------- .../providers/utils/telemetry/sqlite.py | 157 ++++++++++++ 21 files changed, 667 insertions(+), 722 deletions(-) delete mode 100644 llama_stack/providers/inline/meta_reference/telemetry/__init__.py delete mode 100644 llama_stack/providers/inline/meta_reference/telemetry/config.py delete mode 100644 llama_stack/providers/inline/meta_reference/telemetry/console.py rename llama_stack/providers/{remote => inline}/telemetry/__init__.py (100%) create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/__init__.py create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/config.py create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py rename llama_stack/providers/{remote/telemetry/opentelemetry/opentelemetry.py => inline/telemetry/meta_reference/telemetry.py} (61%) rename llama_stack/providers/{remote => inline}/telemetry/sample/__init__.py (100%) rename llama_stack/providers/{remote => inline}/telemetry/sample/config.py (100%) rename llama_stack/providers/{remote => inline}/telemetry/sample/sample.py (100%) delete mode 100644 llama_stack/providers/remote/telemetry/opentelemetry/__init__.py delete mode 100644 llama_stack/providers/remote/telemetry/opentelemetry/config.py delete mode 100644 llama_stack/providers/remote/telemetry/opentelemetry/postgres_processor.py delete mode 100644 llama_stack/providers/utils/telemetry/jaeger.py delete mode 100644 llama_stack/providers/utils/telemetry/postgres.py create mode 100644 llama_stack/providers/utils/telemetry/sqlite.py diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index b982bb9cc..e799851c9 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -147,48 +147,57 @@ class EvalTrace(BaseModel): @json_schema_type -class SpanNode(BaseModel): - span: Span - children: List["SpanNode"] = Field(default_factory=list) +class MaterializedSpan(Span): + children: List["MaterializedSpan"] = Field(default_factory=list) status: Optional[SpanStatus] = None @json_schema_type -class TraceTree(BaseModel): - trace: Trace - root: Optional[SpanNode] = None +class QueryCondition(BaseModel): + key: str + op: str 
+ value: Any class TraceStore(Protocol): - async def get_trace( - self, - trace_id: str, - ) -> TraceTree: ... - async def get_traces_for_sessions( + async def query_traces( self, - session_ids: List[str], - ) -> [Trace]: ... + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... + + async def get_materialized_span( + self, + span_id: str, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> MaterializedSpan: ... @runtime_checkable class Telemetry(Protocol): @webmethod(route="/telemetry/log-event") - async def log_event(self, event: Event) -> None: ... + async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: ... - @webmethod(route="/telemetry/get-trace", method="POST") - async def get_trace(self, trace_id: str) -> TraceTree: ... - - @webmethod(route="/telemetry/get-agent-trace", method="POST") - async def get_agent_trace( + @webmethod(route="/telemetry/query-traces", method="GET") + async def query_traces( self, - session_ids: List[str], - ) -> List[EvalTrace]: ... + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... - @webmethod(route="/telemetry/export-agent-trace", method="POST") - async def export_agent_trace( + @webmethod(route="/telemetry/get-materialized-span", method="GET") + async def get_materialized_span( self, - session_ids: List[str], - dataset_id: str, - ) -> None: ... + span_id: str, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> MaterializedSpan: ... diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8116e2b39..4ae1854df 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -43,9 +43,9 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) -from llama_stack.providers.inline.meta_reference.telemetry.console import ( - ConsoleConfig, - ConsoleTelemetryImpl, +from llama_stack.providers.inline.telemetry.meta_reference import ( + TelemetryAdapter, + TelemetryConfig, ) from .endpoints import get_all_api_endpoints @@ -290,7 +290,7 @@ def main(): if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: - setup_logger(ConsoleTelemetryImpl(ConsoleConfig())) + setup_logger(TelemetryAdapter(TelemetryConfig())) all_endpoints = get_all_api_endpoints() diff --git a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py b/llama_stack/providers/inline/meta_reference/telemetry/__init__.py deleted file mode 100644 index 4a0c2f6ee..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from .config import ConsoleConfig - - -async def get_provider_impl(config: ConsoleConfig, _deps): - from .console import ConsoleTelemetryImpl - - impl = ConsoleTelemetryImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/meta_reference/telemetry/config.py b/llama_stack/providers/inline/meta_reference/telemetry/config.py deleted file mode 100644 index a1db1d4d8..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/config.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - - -class LogFormat(Enum): - TEXT = "text" - JSON = "json" - - -@json_schema_type -class ConsoleConfig(BaseModel): - log_format: LogFormat = LogFormat.TEXT diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py deleted file mode 100644 index 35dcf9561..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/console.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -from typing import List, Optional - -from .config import LogFormat - -from llama_stack.apis.telemetry import * # noqa: F403 -from .config import ConsoleConfig - - -class ConsoleTelemetryImpl(Telemetry): - def __init__(self, config: ConsoleConfig) -> None: - self.config = config - self.spans = {} - - async def initialize(self) -> None: ... - - async def shutdown(self) -> None: ... 
- - async def log_event(self, event: Event): - if ( - isinstance(event, StructuredLogEvent) - and event.payload.type == StructuredLogType.SPAN_START.value - ): - self.spans[event.span_id] = event.payload - - names = [] - span_id = event.span_id - while True: - span_payload = self.spans.get(span_id) - if not span_payload: - break - - names = [span_payload.name] + names - span_id = span_payload.parent_span_id - - span_name = ".".join(names) if names else None - - if self.config.log_format == LogFormat.JSON: - formatted = format_event_json(event, span_name) - else: - formatted = format_event_text(event, span_name) - - if formatted: - print(formatted) - - async def get_trace(self, trace_id: str) -> TraceTree: - raise NotImplementedError("Console telemetry does not support trace retrieval") - - async def get_agent_trace( - self, - session_ids: List[str], - ) -> List[EvalTrace]: - raise NotImplementedError( - "Console telemetry does not support agent trace retrieval" - ) - - async def export_agent_trace( - self, - session_ids: List[str], - dataset_id: str, - ) -> None: - raise NotImplementedError( - "Console telemetry does not support agent trace export" - ) - - -COLORS = { - "reset": "\033[0m", - "bold": "\033[1m", - "dim": "\033[2m", - "red": "\033[31m", - "green": "\033[32m", - "yellow": "\033[33m", - "blue": "\033[34m", - "magenta": "\033[35m", - "cyan": "\033[36m", - "white": "\033[37m", -} - -SEVERITY_COLORS = { - LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"], - LogSeverity.DEBUG: COLORS["cyan"], - LogSeverity.INFO: COLORS["green"], - LogSeverity.WARN: COLORS["yellow"], - LogSeverity.ERROR: COLORS["red"], - LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"], -} - - -def format_event_text(event: Event, span_name: str) -> Optional[str]: - timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3] - span = "" - if span_name: - span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} " - if isinstance(event, UnstructuredLogEvent): - severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"]) - return ( - f"{COLORS['dim']}{timestamp}{COLORS['reset']} " - f"{severity_color}[{event.severity.name}]{COLORS['reset']} " - f"{span}" - f"{event.message}" - ) - - elif isinstance(event, StructuredLogEvent): - return None - - return f"Unknown event type: {event}" - - -def format_event_json(event: Event, span_name: str) -> Optional[str]: - base_data = { - "timestamp": event.timestamp.isoformat(), - "trace_id": event.trace_id, - "span_id": event.span_id, - "span_name": span_name, - } - - if isinstance(event, UnstructuredLogEvent): - base_data.update( - {"type": "log", "severity": event.severity.name, "message": event.message} - ) - return json.dumps(base_data) - - elif isinstance(event, StructuredLogEvent): - return None - - return json.dumps({"error": f"Unknown event type: {event}"}) diff --git a/llama_stack/providers/remote/telemetry/__init__.py b/llama_stack/providers/inline/telemetry/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/__init__.py rename to llama_stack/providers/inline/telemetry/__init__.py diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py new file mode 100644 index 000000000..6213d5536 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from .config import TelemetryConfig, TelemetrySink +from .telemetry import TelemetryAdapter + +__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"] + + +async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]): + impl = TelemetryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py new file mode 100644 index 000000000..0230d24d2 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List + +from pydantic import BaseModel, Field + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + + +class TelemetrySink(str, Enum): + JAEGER = "jaeger" + SQLITE = "sqlite" + CONSOLE = "console" + + +class TelemetryConfig(BaseModel): + otel_endpoint: str = Field( + default="http://localhost:4318/v1/traces", + description="The OpenTelemetry collector endpoint URL", + ) + service_name: str = Field( + default="llama-stack", + description="The service name to use for telemetry", + ) + sinks: List[TelemetrySink] = Field( + default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE], + description="List of telemetry sinks to enable (possible values: jaeger, sqlite, console)", + ) + sqlite_db_path: str = Field( + default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(), + description="The path to the SQLite database to use for storing traces", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", + "sinks": "${env.TELEMETRY_SINKS:['console', 'sqlite']}", + "sqlite_db_path": "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}", + } diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py new file mode 100644 index 000000000..8d6f779e6 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
+from datetime import datetime
+
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.trace.export import SpanProcessor
+from opentelemetry.trace import StatusCode
+
+# Colors for console output
+COLORS = {
+    "reset": "\033[0m",
+    "bold": "\033[1m",
+    "dim": "\033[2m",
+    "red": "\033[31m",
+    "green": "\033[32m",
+    "yellow": "\033[33m",
+    "blue": "\033[34m",
+    "magenta": "\033[35m",
+    "cyan": "\033[36m",
+    "white": "\033[37m",
+}
+
+
+class ConsoleSpanProcessor(SpanProcessor):
+    """A SpanProcessor that prints spans to the console with color formatting."""
+
+    def on_start(self, span: ReadableSpan, parent_context=None) -> None:
+        """Called when a span starts."""
+        timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
+            "%H:%M:%S.%f"
+        )[:-3]
+
+        print(
+            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
+            f"{COLORS['magenta']}[START]{COLORS['reset']} "
+            f"{COLORS['cyan']}{span.name}{COLORS['reset']}"
+        )
+
+    def on_end(self, span: ReadableSpan) -> None:
+        """Called when a span ends."""
+        timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
+            "%H:%M:%S.%f"
+        )[:-3]
+
+        # Build the span context string
+        span_context = (
+            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
+            f"{COLORS['magenta']}[END]{COLORS['reset']} "
+            f"{COLORS['cyan']}{span.name}{COLORS['reset']} "
+        )
+
+        # Add status if not OK. StatusCode is an Enum, so compare members,
+        # not bare ints.
+        if span.status.status_code != StatusCode.OK:  # UNSET or ERROR
+            status_color = (
+                COLORS["red"]
+                if span.status.status_code == StatusCode.ERROR
+                else COLORS["yellow"]
+            )
+            span_context += (
+                f" {status_color}[{span.status.status_code.name}]{COLORS['reset']}"
+            )
+
+        # Add duration
+        duration_ms = (span.end_time - span.start_time) / 1e6
+        span_context += f" {COLORS['dim']}({duration_ms:.2f}ms){COLORS['reset']}"
+
+        # Print the main span line
+        print(span_context)
+
+        # Print attributes indented
+        if span.attributes:
+            for key, value in span.attributes.items():
+                print(f"    {COLORS['dim']}{key}: {value}{COLORS['reset']}")
+
+        # Print events indented
+        for event in span.events:
+            event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
+                "%H:%M:%S.%f"
+            )[:-3]
+            print(
+                f"    {COLORS['dim']}{event_time}{COLORS['reset']} "
+                f"{COLORS['cyan']}[EVENT]{COLORS['reset']} {event.name}"
+            )
+            if event.attributes:
+                for key, value in event.attributes.items():
+                    print(f"      {COLORS['dim']}{key}: {value}{COLORS['reset']}")
+
+    def shutdown(self) -> None:
+        """Shutdown the processor."""
+        pass
+
+    def force_flush(self, timeout_millis: float | None = None) -> bool:
+        """Force flush any pending spans."""
+        return True
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py
new file mode 100644
index 000000000..553dd5000
--- /dev/null
+++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py
@@ -0,0 +1,242 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+ +import json +import os +import sqlite3 +import threading +from datetime import datetime, timedelta +from typing import Dict + +from opentelemetry.sdk.trace import SpanProcessor +from opentelemetry.trace import Span + + +class SQLiteSpanProcessor(SpanProcessor): + def __init__(self, conn_string, ttl_days=30): + """Initialize the SQLite span processor with a connection string.""" + self.conn_string = conn_string + self.ttl_days = ttl_days + self.cleanup_task = None + self._thread_local = threading.local() + self._connections: Dict[int, sqlite3.Connection] = {} + self._lock = threading.Lock() + self.setup_database() + + def _get_connection(self) -> sqlite3.Connection: + """Get a thread-specific database connection.""" + thread_id = threading.get_ident() + with self._lock: + if thread_id not in self._connections: + conn = sqlite3.connect(self.conn_string) + self._connections[thread_id] = conn + return self._connections[thread_id] + + def setup_database(self): + """Create the necessary tables if they don't exist.""" + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(self.conn_string), exist_ok=True) + + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS traces ( + trace_id TEXT PRIMARY KEY, + service_name TEXT, + root_span_id TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS spans ( + span_id TEXT PRIMARY KEY, + trace_id TEXT REFERENCES traces(trace_id), + parent_span_id TEXT, + name TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + attributes TEXT, + status TEXT, + kind TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS span_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + span_id TEXT REFERENCES spans(span_id), + name TEXT, + timestamp TIMESTAMP, + attributes TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_traces_created_at + ON traces(created_at) + """ + ) + + conn.commit() + cursor.close() + + # Start periodic cleanup in a separate thread + self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True) + self.cleanup_task.start() + + def _cleanup_old_data(self): + """Delete records older than TTL.""" + try: + conn = self._get_connection() + cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat() + cursor = conn.cursor() + + # Delete old span events + cursor.execute( + """ + DELETE FROM span_events + WHERE span_id IN ( + SELECT span_id FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + ) + """, + (cutoff_date,), + ) + + # Delete old spans + cursor.execute( + """ + DELETE FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + """, + (cutoff_date,), + ) + + # Delete old traces + cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,)) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error during cleanup: {e}") + + def _periodic_cleanup(self): + """Run cleanup periodically.""" + import time + + while True: + time.sleep(3600) # Sleep for 1 hour + self._cleanup_old_data() + + def on_start(self, span: Span, parent_context=None): + """Called when a span starts.""" + pass + + def on_end(self, span: Span): + """Called when a span ends. 
Export the span data to SQLite.""" + try: + conn = self._get_connection() + cursor = conn.cursor() + + trace_id = format(span.get_span_context().trace_id, "032x") + span_id = format(span.get_span_context().span_id, "016x") + service_name = span.resource.attributes.get("service.name", "unknown") + + parent_span_id = None + parent_context = span.parent + if parent_context: + parent_span_id = format(parent_context.span_id, "016x") + + # Insert into traces + cursor.execute( + """ + INSERT INTO traces ( + trace_id, service_name, root_span_id, start_time, end_time + ) VALUES (?, ?, ?, ?, ?) + ON CONFLICT(trace_id) DO UPDATE SET + root_span_id = COALESCE(root_span_id, excluded.root_span_id), + start_time = MIN(excluded.start_time, start_time), + end_time = MAX(excluded.end_time, end_time) + """, + ( + trace_id, + service_name, + (span_id if not parent_span_id else None), + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + ), + ) + + # Insert into spans + cursor.execute( + """ + INSERT INTO spans ( + span_id, trace_id, parent_span_id, name, + start_time, end_time, attributes, status, + kind + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + span_id, + trace_id, + parent_span_id, + span.name, + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + json.dumps(dict(span.attributes)), + span.status.status_code.name, + span.kind.name, + ), + ) + + for event in span.events: + cursor.execute( + """ + INSERT INTO span_events ( + span_id, name, timestamp, attributes + ) VALUES (?, ?, ?, ?) + """, + ( + span_id, + event.name, + datetime.fromtimestamp(event.timestamp / 1e9).isoformat(), + json.dumps(dict(event.attributes)), + ), + ) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error exporting span to SQLite: {e}") + + def shutdown(self): + """Cleanup any resources.""" + with self._lock: + for conn in self._connections.values(): + if conn: + conn.close() + self._connections.clear() + + def force_flush(self, timeout_millis=30000): + """Force export of spans.""" + pass diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py similarity index 61% rename from llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py rename to llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 59094a080..1f27876e0 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import threading -from typing import List +from typing import List, Optional from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -17,17 +17,18 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes -from llama_stack.distribution.datatypes import Api -from llama_stack.providers.remote.telemetry.opentelemetry.postgres_processor import ( - PostgresSpanProcessor, +from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( + ConsoleSpanProcessor, ) -from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore -from llama_stack.providers.utils.telemetry.postgres import PostgresTraceStore +from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( + SQLiteSpanProcessor, +) +from llama_stack.providers.utils.telemetry.sqlite import SQLiteTraceStore from llama_stack.apis.telemetry import * # noqa: F403 -from .config import OpenTelemetryConfig +from .config import TelemetryConfig, TelemetrySink _GLOBAL_STORAGE = { "active_spans": {}, @@ -53,19 +54,9 @@ def is_tracing_enabled(tracer): return span.is_recording() -class OpenTelemetryAdapter(Telemetry): - def __init__(self, config: OpenTelemetryConfig, deps) -> None: +class TelemetryAdapter(Telemetry): + def __init__(self, config: TelemetryConfig) -> None: self.config = config - self.datasetio = deps[Api.datasetio] - - if config.trace_store == "jaeger": - self.trace_store = JaegerTraceStore( - config.jaeger_query_endpoint, config.service_name - ) - elif config.trace_store == "postgres": - self.trace_store = PostgresTraceStore(config.postgres_conn_string) - else: - raise ValueError(f"Invalid trace store: {config.trace_store}") resource = Resource.create( { @@ -75,25 +66,29 @@ class OpenTelemetryAdapter(Telemetry): provider = TracerProvider(resource=resource) trace.set_tracer_provider(provider) - otlp_exporter = OTLPSpanExporter( - endpoint=self.config.otel_endpoint, - ) - span_processor = BatchSpanProcessor(otlp_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) - trace.get_tracer_provider().add_span_processor( - PostgresSpanProcessor(self.config.postgres_conn_string) - ) - # Set up metrics - metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter( + if TelemetrySink.JAEGER in self.config.sinks: + otlp_exporter = OTLPSpanExporter( endpoint=self.config.otel_endpoint, ) - ) - metric_provider = MeterProvider( - resource=resource, metric_readers=[metric_reader] - ) - metrics.set_meter_provider(metric_provider) - self.meter = metrics.get_meter(__name__) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider = MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) self._lock = _global_lock async def initialize(self) -> 
None: @@ -104,15 +99,17 @@ class OpenTelemetryAdapter(Telemetry): trace.get_tracer_provider().shutdown() metrics.get_meter_provider().shutdown() - async def log_event(self, event: Event) -> None: + async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: if isinstance(event, UnstructuredLogEvent): - self._log_unstructured(event) + self._log_unstructured(event, ttl_seconds) elif isinstance(event, MetricEvent): self._log_metric(event) elif isinstance(event, StructuredLogEvent): - self._log_structured(event) + self._log_structured(event, ttl_seconds) + else: + raise ValueError(f"Unknown event type: {event}") - def _log_unstructured(self, event: UnstructuredLogEvent) -> None: + def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: # Use global storage instead of instance storage span_id = string_to_span_id(event.span_id) @@ -125,6 +122,7 @@ class OpenTelemetryAdapter(Telemetry): attributes={ "message": event.message, "severity": event.severity.value, + "__ttl__": ttl_seconds, **event.attributes, }, timestamp=timestamp_ns, @@ -175,11 +173,14 @@ class OpenTelemetryAdapter(Telemetry): ) return _GLOBAL_STORAGE["up_down_counters"][name] - def _log_structured(self, event: StructuredLogEvent) -> None: + def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: span_id = string_to_span_id(event.span_id) trace_id = string_to_trace_id(event.trace_id) tracer = trace.get_tracer(__name__) + if event.attributes is None: + event.attributes = {} + event.attributes["__ttl__"] = ttl_seconds if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates @@ -216,66 +217,33 @@ class OpenTelemetryAdapter(Telemetry): span.set_status(status) span.end() _GLOBAL_STORAGE["active_spans"].pop(span_id, None) + else: + raise ValueError(f"Unknown structured log event: {event}") - async def get_trace(self, trace_id: str) -> TraceTree: - return await self.trace_store.get_trace(trace_id) - - async def get_agent_trace( + async def query_traces( self, - session_ids: List[str], - ) -> List[EvalTrace]: - traces = [] - for session_id in session_ids: - traces_for_session = await self.trace_store.get_traces_for_sessions( - [session_id] - ) - for session_trace in traces_for_session: - trace_details = await self._get_simplified_agent_trace( - session_trace.trace_id, session_id - ) - traces.extend(trace_details) + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + return await self.trace_store.query_traces( + attribute_conditions=attribute_conditions, + attribute_keys_to_return=attribute_keys_to_return, + limit=limit, + offset=offset, + order_by=order_by, + ) - return traces - - async def export_agent_trace(self, session_ids: List[str], dataset_id: str) -> None: - traces = await self.get_agent_trace(session_ids) - traces_dict = [ - { - "step": trace.step, - "input": trace.input, - "output": trace.output, - "session_id": trace.session_id, - } - for trace in traces - ] - await self.datasetio.upload_rows(dataset_id, traces_dict) - - async def _get_simplified_agent_trace( - self, trace_id: str, session_id: str - ) -> List[EvalTrace]: - trace_tree = await self.get_trace(trace_id) - if not trace_tree or not trace_tree.root: - return [] - - def find_execute_turn_children(node: SpanNode) -> List[EvalTrace]: - results = [] 
- if node.span.name == "create_and_execute_turn": - # Sort children by start time - sorted_children = sorted(node.children, key=lambda x: x.span.start_time) - for child in sorted_children: - results.append( - EvalTrace( - step=child.span.name, - input=str(child.span.attributes.get("input", "")), - output=str(child.span.attributes.get("output", "")), - session_id=session_id, - expected_output="", - ) - ) - - # Recursively process children - for child in node.children: - results.extend(find_execute_turn_children(child)) - return results - - return find_execute_turn_children(trace_tree.root) + async def get_materialized_span( + self, + span_id: str, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> MaterializedSpan: + return await self.trace_store.get_materialized_span( + span_id=span_id, + attribute_keys_to_return=attribute_keys_to_return, + max_depth=max_depth, + ) diff --git a/llama_stack/providers/remote/telemetry/sample/__init__.py b/llama_stack/providers/inline/telemetry/sample/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/__init__.py rename to llama_stack/providers/inline/telemetry/sample/__init__.py diff --git a/llama_stack/providers/remote/telemetry/sample/config.py b/llama_stack/providers/inline/telemetry/sample/config.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/config.py rename to llama_stack/providers/inline/telemetry/sample/config.py diff --git a/llama_stack/providers/remote/telemetry/sample/sample.py b/llama_stack/providers/inline/telemetry/sample/sample.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/sample.py rename to llama_stack/providers/inline/telemetry/sample/sample.py diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index 94b63ab1c..f7a7e65d3 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.telemetry, provider_type="inline::meta-reference", pip_packages=[], - module="llama_stack.providers.inline.meta_reference.telemetry", - config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig", + module="llama_stack.providers.inline.telemetry.meta_reference", + config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig", ), remote_provider_spec( api=Api.telemetry, @@ -27,23 +27,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig", ), ), - RemoteProviderSpec( - api=Api.telemetry, - provider_type="remote::opentelemetry-jaeger", - config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", - adapter=AdapterSpec( - adapter_type="opentelemetry-jaeger", - pip_packages=[ - "opentelemetry-api", - "opentelemetry-sdk", - "opentelemetry-exporter-jaeger", - "opentelemetry-semantic-conventions", - ], - module="llama_stack.providers.remote.telemetry.opentelemetry", - config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", - ), - api_dependencies=[ - Api.datasetio, - ], - ), ] diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py deleted file mode 100644 index 56050411d..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py +++ /dev/null @@ 
-1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import OpenTelemetryConfig - - -async def get_adapter_impl(config: OpenTelemetryConfig, deps): - from .opentelemetry import OpenTelemetryAdapter - - impl = OpenTelemetryAdapter(config, deps) - await impl.initialize() - return impl diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py deleted file mode 100644 index 9d829d110..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from pydantic import BaseModel, Field - - -class OpenTelemetryConfig(BaseModel): - otel_endpoint: str = Field( - default="http://localhost:4318/v1/traces", - description="The OpenTelemetry collector endpoint URL", - ) - service_name: str = Field( - default="llama-stack", - description="The service name to use for telemetry", - ) - trace_store: str = Field( - default="postgres", - description="The trace store to use for telemetry", - ) - jaeger_query_endpoint: str = Field( - default="http://localhost:16686/api/traces", - description="The Jaeger query endpoint URL", - ) - postgres_conn_string: str = Field( - default="host=localhost dbname=llama_stack user=llama_stack password=llama_stack port=5432", - description="The PostgreSQL connection string to use for storing traces", - ) - - @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: - return { - "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}", - "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", - } diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/postgres_processor.py b/llama_stack/providers/remote/telemetry/opentelemetry/postgres_processor.py deleted file mode 100644 index de8bf15b6..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/postgres_processor.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -from datetime import datetime - -import psycopg2 -from opentelemetry.sdk.trace import SpanProcessor -from opentelemetry.trace import Span - - -class PostgresSpanProcessor(SpanProcessor): - def __init__(self, conn_string): - """Initialize the PostgreSQL span processor with a connection string.""" - self.conn_string = conn_string - self.conn = None - self.setup_database() - - def setup_database(self): - """Create the necessary table if it doesn't exist.""" - with psycopg2.connect(self.conn_string) as conn: - with conn.cursor() as cur: - cur.execute( - """ - CREATE TABLE IF NOT EXISTS traces ( - trace_id TEXT, - span_id TEXT, - parent_span_id TEXT, - name TEXT, - start_time TIMESTAMP, - end_time TIMESTAMP, - attributes JSONB, - status TEXT, - kind TEXT, - service_name TEXT, - session_id TEXT - ) - """ - ) - conn.commit() - - def on_start(self, span: Span, parent_context=None): - """Called when a span starts.""" - pass - - def on_end(self, span: Span): - """Called when a span ends. 
Export the span data to PostgreSQL.""" - try: - with psycopg2.connect(self.conn_string) as conn: - with conn.cursor() as cur: - - cur.execute( - """ - INSERT INTO traces ( - trace_id, span_id, parent_span_id, name, - start_time, end_time, attributes, status, - kind, service_name, session_id - ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, - ( - format(span.get_span_context().trace_id, "032x"), - format(span.get_span_context().span_id, "016x"), - ( - format(span.parent.span_id, "016x") - if span.parent - else None - ), - span.name, - datetime.fromtimestamp(span.start_time / 1e9), - datetime.fromtimestamp(span.end_time / 1e9), - json.dumps(dict(span.attributes)), - span.status.status_code.name, - span.kind.name, - span.resource.attributes.get("service.name", "unknown"), - span.attributes.get("session_id", None), - ), - ) - conn.commit() - except Exception as e: - print(f"Error exporting span to PostgreSQL: {e}") - - def shutdown(self): - """Cleanup any resources.""" - if self.conn: - self.conn.close() - - def force_flush(self, timeout_millis=30000): - """Force export of spans.""" - pass diff --git a/llama_stack/providers/utils/telemetry/jaeger.py b/llama_stack/providers/utils/telemetry/jaeger.py deleted file mode 100644 index 3c28748ed..000000000 --- a/llama_stack/providers/utils/telemetry/jaeger.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from datetime import datetime, timedelta -from typing import List - -import aiohttp - -from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree - - -class JaegerTraceStore(TraceStore): - def __init__(self, endpoint: str, service_name: str): - self.endpoint = endpoint - self.service_name = service_name - - async def get_trace(self, trace_id: str) -> TraceTree: - params = { - "traceID": trace_id, - } - - try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"{self.endpoint}/{trace_id}", params=params - ) as response: - if response.status != 200: - raise Exception( - f"Failed to query Jaeger: {response.status} {await response.text()}" - ) - - trace_data = await response.json() - if not trace_data.get("data") or not trace_data["data"]: - return None - - # First pass: Build span map - span_map = {} - for jaeger_span in trace_data["data"][0]["spans"]: - start_time = datetime.fromtimestamp( - jaeger_span["startTime"] / 1000000 - ) - - # Some systems store end time directly in the span - if "endTime" in jaeger_span: - end_time = datetime.fromtimestamp( - jaeger_span["endTime"] / 1000000 - ) - else: - duration_microseconds = jaeger_span.get("duration", 0) - duration_timedelta = timedelta( - microseconds=duration_microseconds - ) - end_time = start_time + duration_timedelta - - span = Span( - span_id=jaeger_span["spanID"], - trace_id=trace_id, - name=jaeger_span["operationName"], - start_time=start_time, - end_time=end_time, - parent_span_id=next( - ( - ref["spanID"] - for ref in jaeger_span.get("references", []) - if ref["refType"] == "CHILD_OF" - ), - None, - ), - attributes={ - tag["key"]: tag["value"] - for tag in jaeger_span.get("tags", []) - }, - ) - - span_map[span.span_id] = SpanNode(span=span) - - # Second pass: Build parent-child relationships - root_node = None - for span_node in span_map.values(): - parent_id = span_node.span.parent_span_id - if parent_id and parent_id in span_map: - 
span_map[parent_id].children.append(span_node) - elif not parent_id: - root_node = span_node - - trace = Trace( - trace_id=trace_id, - root_span_id=root_node.span.span_id if root_node else "", - start_time=( - root_node.span.start_time if root_node else datetime.now() - ), - end_time=root_node.span.end_time if root_node else None, - ) - - return TraceTree(trace=trace, root=root_node) - - except Exception as e: - raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e - - async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]: - traces = [] - - # Fetch traces for each session ID individually - for session_id in session_ids: - params = { - "service": self.service_name, - "tags": f'{{"session_id":"{session_id}"}}', - "limit": 100, - "lookback": "10000h", - } - - try: - async with aiohttp.ClientSession() as session: - async with session.get(self.endpoint, params=params) as response: - if response.status != 200: - raise Exception( - f"Failed to query Jaeger: {response.status} {await response.text()}" - ) - - traces_data = await response.json() - seen_trace_ids = set() - - for trace_data in traces_data.get("data", []): - trace_id = trace_data.get("traceID") - if trace_id and trace_id not in seen_trace_ids: - seen_trace_ids.add(trace_id) - traces.append( - Trace( - trace_id=trace_id, - root_span_id="", - start_time=datetime.now(), - ) - ) - - except Exception as e: - raise Exception(f"Error querying Jaeger traces: {str(e)}") from e - - return traces diff --git a/llama_stack/providers/utils/telemetry/postgres.py b/llama_stack/providers/utils/telemetry/postgres.py deleted file mode 100644 index ed68fc293..000000000 --- a/llama_stack/providers/utils/telemetry/postgres.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import json -from datetime import datetime -from typing import List, Optional - -import psycopg2 - -from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree - - -class PostgresTraceStore(TraceStore): - def __init__(self, conn_string: str): - self.conn_string = conn_string - - async def get_trace(self, trace_id: str) -> Optional[TraceTree]: - try: - with psycopg2.connect(self.conn_string) as conn: - with conn.cursor() as cur: - # Fetch all spans for the trace - cur.execute( - """ - SELECT trace_id, span_id, parent_span_id, name, - start_time, end_time, attributes - FROM traces - WHERE trace_id = %s - """, - (trace_id,), - ) - spans_data = cur.fetchall() - - if not spans_data: - return None - - # First pass: Build span map - span_map = {} - for span_data in spans_data: - # Ensure attributes is a string before parsing - attributes = span_data[6] - if isinstance(attributes, dict): - attributes = json.dumps(attributes) - - span = Span( - span_id=span_data[1], - trace_id=span_data[0], - name=span_data[3], - start_time=span_data[4], - end_time=span_data[5], - parent_span_id=span_data[2], - attributes=json.loads( - attributes - ), # Now safely parse the JSON string - ) - span_map[span.span_id] = SpanNode(span=span) - - # Second pass: Build parent-child relationships - root_node = None - for span_node in span_map.values(): - parent_id = span_node.span.parent_span_id - if parent_id and parent_id in span_map: - span_map[parent_id].children.append(span_node) - elif not parent_id: - root_node = span_node - - trace = Trace( - trace_id=trace_id, - root_span_id=root_node.span.span_id if root_node else "", - start_time=( - root_node.span.start_time if root_node else datetime.now() - ), - end_time=root_node.span.end_time if root_node else None, - ) - - return TraceTree(trace=trace, root=root_node) - - except Exception as e: - raise Exception( - f"Error querying PostgreSQL trace structure: {str(e)}" - ) from e - - async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]: - traces = [] - try: - with psycopg2.connect(self.conn_string) as conn: - with conn.cursor() as cur: - # Query traces for all session IDs - cur.execute( - """ - SELECT DISTINCT trace_id, MIN(start_time) as start_time - FROM traces - WHERE attributes->>'session_id' = ANY(%s) - GROUP BY trace_id - """, - (session_ids,), - ) - traces_data = cur.fetchall() - - for trace_data in traces_data: - traces.append( - Trace( - trace_id=trace_data[0], - root_span_id="", - start_time=trace_data[1], - ) - ) - - except Exception as e: - raise Exception(f"Error querying PostgreSQL traces: {str(e)}") from e - - return traces diff --git a/llama_stack/providers/utils/telemetry/sqlite.py b/llama_stack/providers/utils/telemetry/sqlite.py new file mode 100644 index 000000000..4ccabf200 --- /dev/null +++ b/llama_stack/providers/utils/telemetry/sqlite.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import json +from datetime import datetime +from typing import List, Optional + +import aiosqlite + +from llama_stack.apis.telemetry import ( + MaterializedSpan, + QueryCondition, + Trace, + TraceStore, +) + + +class SQLiteTraceStore(TraceStore): + def __init__(self, conn_string: str): + self.conn_string = conn_string + + async def query_traces( + self, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + # Build the SQL query with attribute selection + select_clause = """ + SELECT DISTINCT t.trace_id, t.root_span_id, t.start_time, t.end_time + """ + if attribute_keys_to_return: + for key in attribute_keys_to_return: + select_clause += ( + f", json_extract(s.attributes, '$.{key}') as attr_{key}" + ) + + query = ( + select_clause + + """ + FROM traces t + JOIN spans s ON t.trace_id = s.trace_id + """ + ) + params = [] + + # Add attribute conditions if present + if attribute_conditions: + conditions = [] + for condition in attribute_conditions: + conditions.append( + f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?" + ) + params.append(condition.value) + if conditions: + query += " WHERE " + " AND ".join(conditions) + + # Add ordering + if order_by: + order_clauses = [] + for field in order_by: + desc = False + if field.startswith("-"): + field = field[1:] + desc = True + order_clauses.append(f"t.{field} {'DESC' if desc else 'ASC'}") + query += " ORDER BY " + ", ".join(order_clauses) + + # Add limit and offset + query += f" LIMIT {limit} OFFSET {offset}" + + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [ + Trace( + trace_id=row["trace_id"], + root_span_id=row["root_span_id"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + ) + for row in rows + ] + + async def get_materialized_span( + self, + span_id: str, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> MaterializedSpan: + # Build the attributes selection + attributes_select = "s.attributes" + if attribute_keys_to_return: + json_object = ", ".join( + f"'{key}', json_extract(s.attributes, '$.{key}')" + for key in attribute_keys_to_return + ) + attributes_select = f"json_object({json_object})" + + # SQLite CTE query with filtered attributes + query = f""" + WITH RECURSIVE span_tree AS ( + SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes + FROM spans s + WHERE s.span_id = ? + + UNION ALL + + SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes + FROM spans s + JOIN span_tree st ON s.parent_span_id = st.span_id + WHERE (? IS NULL OR st.depth < ?) 
+ ) + SELECT * + FROM span_tree + ORDER BY depth, start_time + """ + + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor: + rows = await cursor.fetchall() + + if not rows: + raise ValueError(f"Span {span_id} not found") + + # Build span tree + spans_by_id = {} + root_span = None + + for row in rows: + span = MaterializedSpan( + span_id=row["span_id"], + trace_id=row["trace_id"], + parent_span_id=row["parent_span_id"], + name=row["name"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + attributes=json.loads(row["filtered_attributes"]), + status=row["status"].lower(), + children=[], + ) + + spans_by_id[span.span_id] = span + + if span.span_id == span_id: + root_span = span + elif span.parent_span_id in spans_by_id: + spans_by_id[span.parent_span_id].children.append(span) + + return root_span