consolidate telemetry to meta reference inline

This commit is contained in:
Dinesh Yeduguru 2024-12-03 16:25:20 -08:00
parent cb49d21a49
commit 5d0c502cdb
21 changed files with 667 additions and 722 deletions

View file

@ -147,48 +147,57 @@ class EvalTrace(BaseModel):
@json_schema_type
class SpanNode(BaseModel):
span: Span
children: List["SpanNode"] = Field(default_factory=list)
class MaterializedSpan(Span):
children: List["MaterializedSpan"] = Field(default_factory=list)
status: Optional[SpanStatus] = None
@json_schema_type
class TraceTree(BaseModel):
trace: Trace
root: Optional[SpanNode] = None
class QueryCondition(BaseModel):
key: str
op: str
value: Any
class TraceStore(Protocol):
async def get_trace(
self,
trace_id: str,
) -> TraceTree: ...
async def get_traces_for_sessions(
async def query_traces(
self,
session_ids: List[str],
) -> [Trace]: ...
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
async def get_materialized_span(
self,
span_id: str,
attribute_keys_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> MaterializedSpan: ...
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/log-event")
async def log_event(self, event: Event) -> None: ...
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: ...
@webmethod(route="/telemetry/get-trace", method="POST")
async def get_trace(self, trace_id: str) -> TraceTree: ...
@webmethod(route="/telemetry/get-agent-trace", method="POST")
async def get_agent_trace(
@webmethod(route="/telemetry/query-traces", method="GET")
async def query_traces(
self,
session_ids: List[str],
) -> List[EvalTrace]: ...
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
@webmethod(route="/telemetry/export-agent-trace", method="POST")
async def export_agent_trace(
@webmethod(route="/telemetry/get-materialized-span", method="GET")
async def get_materialized_span(
self,
session_ids: List[str],
dataset_id: str,
) -> None: ...
span_id: str,
attribute_keys_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> MaterializedSpan: ...

View file

@ -43,9 +43,9 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.inline.meta_reference.telemetry.console import (
ConsoleConfig,
ConsoleTelemetryImpl,
from llama_stack.providers.inline.telemetry.meta_reference import (
TelemetryAdapter,
TelemetryConfig,
)
from .endpoints import get_all_api_endpoints
@ -290,7 +290,7 @@ def main():
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
setup_logger(TelemetryAdapter(TelemetryConfig()))
all_endpoints = get_all_api_endpoints()

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ConsoleConfig
async def get_provider_impl(config: ConsoleConfig, _deps):
    """Build and initialize the console telemetry provider.

    Args:
        config: Console telemetry settings passed to the implementation.
        _deps: Accepted but not used by this function.

    Returns:
        The ``ConsoleTelemetryImpl`` instance after ``initialize()`` was awaited.
    """
    # Imported inside the function, so the module is only loaded when called.
    from .console import ConsoleTelemetryImpl

    telemetry = ConsoleTelemetryImpl(config)
    await telemetry.initialize()
    return telemetry

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
class LogFormat(Enum):
    """Output formats selectable for console telemetry."""

    # Values are the strings used in configuration.
    TEXT = "text"
    JSON = "json"
@json_schema_type
class ConsoleConfig(BaseModel):
    """Configuration for the console telemetry provider."""

    # Format used when printing events; defaults to text output.
    log_format: LogFormat = LogFormat.TEXT

View file

@ -1,133 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import List, Optional
from .config import LogFormat
from llama_stack.apis.telemetry import * # noqa: F403
from .config import ConsoleConfig
class ConsoleTelemetryImpl(Telemetry):
    """Telemetry provider that prints log events to standard output.

    Span-start payloads are remembered so each printed event can be labelled
    with the dotted path of its enclosing spans. Trace retrieval and export
    are not supported and raise ``NotImplementedError``.
    """

    def __init__(self, config: ConsoleConfig) -> None:
        # span_id -> span-start payload, filled in by ``log_event``.
        self.spans = {}
        self.config = config

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def log_event(self, event: Event):
        """Record span starts and print the formatted event, if any.

        Args:
            event: The telemetry event to log.
        """
        is_span_start = (
            isinstance(event, StructuredLogEvent)
            and event.payload.type == StructuredLogType.SPAN_START.value
        )
        if is_span_start:
            self.spans[event.span_id] = event.payload

        # Walk from the event's span up through its recorded ancestors.
        lineage = []
        cursor = event.span_id
        payload = self.spans.get(cursor)
        while payload:
            lineage.append(payload.name)
            cursor = payload.parent_span_id
            payload = self.spans.get(cursor)
        lineage.reverse()  # outermost span first
        span_name = ".".join(lineage) if lineage else None

        if self.config.log_format == LogFormat.JSON:
            line = format_event_json(event, span_name)
        else:
            line = format_event_text(event, span_name)

        # Formatters return None for events that should not be printed.
        if line:
            print(line)

    async def get_trace(self, trace_id: str) -> TraceTree:
        """Not supported by the console provider; always raises."""
        raise NotImplementedError("Console telemetry does not support trace retrieval")

    async def get_agent_trace(
        self,
        session_ids: List[str],
    ) -> List[EvalTrace]:
        """Not supported by the console provider; always raises."""
        raise NotImplementedError(
            "Console telemetry does not support agent trace retrieval"
        )

    async def export_agent_trace(
        self,
        session_ids: List[str],
        dataset_id: str,
    ) -> None:
        """Not supported by the console provider; always raises."""
        raise NotImplementedError(
            "Console telemetry does not support agent trace export"
        )
# ANSI escape sequences keyed by a readable name; each is ESC[<code>m.
COLORS = {
    label: f"\033[{sgr}m"
    for label, sgr in (
        ("reset", 0),
        ("bold", 1),
        ("dim", 2),
        ("red", 31),
        ("green", 32),
        ("yellow", 33),
        ("blue", 34),
        ("magenta", 35),
        ("cyan", 36),
        ("white", 37),
    )
}
# Escape sequence used to tint the severity label of an unstructured log line.
SEVERITY_COLORS = {
    LogSeverity.VERBOSE: f"{COLORS['dim']}{COLORS['white']}",
    LogSeverity.DEBUG: COLORS["cyan"],
    LogSeverity.INFO: COLORS["green"],
    LogSeverity.WARN: COLORS["yellow"],
    LogSeverity.ERROR: COLORS["red"],
    LogSeverity.CRITICAL: f"{COLORS['bold']}{COLORS['red']}",
}
def format_event_text(event: Event, span_name: str) -> Optional[str]:
    """Render an event as a colorized console line.

    Args:
        event: The event to format.
        span_name: Dotted span path to show in brackets; omitted when falsy.

    Returns:
        The formatted line for unstructured log events, ``None`` for
        structured log events, and an "Unknown event type" message otherwise.
    """
    reset = COLORS["reset"]
    # Trim microseconds down to milliseconds.
    clock = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
    span_tag = f"{COLORS['magenta']}[{span_name}]{reset} " if span_name else ""

    if isinstance(event, UnstructuredLogEvent):
        tint = SEVERITY_COLORS.get(event.severity, reset)
        pieces = [
            f"{COLORS['dim']}{clock}{reset} ",
            f"{tint}[{event.severity.name}]{reset} ",
            f"{span_tag}",
            f"{event.message}",
        ]
        return "".join(pieces)

    if isinstance(event, StructuredLogEvent):
        return None

    return f"Unknown event type: {event}"
def format_event_json(event: Event, span_name: str) -> Optional[str]:
    """Render an event as a single JSON document.

    Args:
        event: The event to format.
        span_name: Dotted span path stored under ``"span_name"``.

    Returns:
        A JSON string for unstructured log events, ``None`` for structured
        log events, and a JSON error object for any other event type.
    """
    record = dict(
        timestamp=event.timestamp.isoformat(),
        trace_id=event.trace_id,
        span_id=event.span_id,
        span_name=span_name,
    )

    if isinstance(event, UnstructuredLogEvent):
        record["type"] = "log"
        record["severity"] = event.severity.name
        record["message"] = event.message
        return json.dumps(record)

    if isinstance(event, StructuredLogEvent):
        return None

    return json.dumps({"error": f"Unknown event type: {event}"})

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import TelemetryConfig, TelemetrySink
from .telemetry import TelemetryAdapter
__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
    """Build and initialize the telemetry adapter.

    Args:
        config: Telemetry settings passed to ``TelemetryAdapter``.
        deps: Accepted but not used by this function.

    Returns:
        The ``TelemetryAdapter`` after ``initialize()`` was awaited.
    """
    adapter = TelemetryAdapter(config)
    await adapter.initialize()
    return adapter

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
    """Destinations that telemetry can be sent to.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    JAEGER = "jaeger"
    SQLITE = "sqlite"
    CONSOLE = "console"
class TelemetryConfig(BaseModel):
    """Settings for the inline meta-reference telemetry provider."""

    otel_endpoint: str = Field(
        description="The OpenTelemetry collector endpoint URL",
        default="http://localhost:4318/v1/traces",
    )
    service_name: str = Field(
        description="The service name to use for telemetry",
        default="llama-stack",
    )
    sinks: List[TelemetrySink] = Field(
        description="List of telemetry sinks to enable (possible values: jaeger, sqlite, console)",
        default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
    )
    sqlite_db_path: str = Field(
        description="The path to the SQLite database to use for storing traces",
        default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        """Return a sample run configuration using environment placeholders.

        Args:
            **kwargs: Accepted but not used.

        Returns:
            A dict of placeholder strings for ``service_name``, ``sinks`` and
            ``sqlite_db_path``; ``otel_endpoint`` is not included.
        """
        sample: Dict[str, Any] = {}
        sample["service_name"] = "${env.OTEL_SERVICE_NAME:llama-stack}"
        # NOTE(review): this placeholder is a string while ``sinks`` is declared
        # as a list -- confirm the config loader turns it into a list.
        sample["sinks"] = "${env.TELEMETRY_SINKS:['console', 'sqlite']}"
        sample["sqlite_db_path"] = (
            "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}"
        )
        return sample

View file

@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
# Colors for console output: ANSI escape sequences of the form ESC[<code>m.
COLORS = {
    label: f"\033[{sgr}m"
    for label, sgr in (
        ("reset", 0),
        ("bold", 1),
        ("dim", 2),
        ("red", 31),
        ("green", 32),
        ("yellow", 33),
        ("blue", 34),
        ("magenta", 35),
        ("cyan", 36),
        ("white", 37),
    )
}
class ConsoleSpanProcessor(SpanProcessor):
    """A SpanProcessor that prints spans to the console with color formatting."""

    @staticmethod
    def _clock(timestamp_ns) -> str:
        """Format a nanosecond epoch timestamp as UTC ``HH:MM:SS.mmm``."""
        # Local import: only ``datetime`` is imported at file level, and the
        # timezone-aware call replaces the deprecated ``utcfromtimestamp``.
        from datetime import timezone

        moment = datetime.fromtimestamp(timestamp_ns / 1e9, tz=timezone.utc)
        # Trim microseconds down to milliseconds.
        return moment.strftime("%H:%M:%S.%f")[:-3]

    def on_start(self, span: ReadableSpan, parent_context=None) -> None:
        """Called when a span starts; prints a ``[START]`` line with its name."""
        timestamp = self._clock(span.start_time)
        print(
            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
            f"{COLORS['magenta']}[START]{COLORS['reset']} "
            f"{COLORS['cyan']}{span.name}{COLORS['reset']}"
        )

    def on_end(self, span: ReadableSpan) -> None:
        """Called when a span ends.

        Prints an ``[END]`` line with the span name, its status when that is
        not OK, and its duration, followed by indented attributes and events.
        """
        timestamp = self._clock(span.end_time)

        # Build the span context string
        span_context = (
            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
            f"{COLORS['magenta']}[END]{COLORS['reset']} "
            f"{COLORS['cyan']}{span.name}{COLORS['reset']} "
        )

        # Add status if not OK. The status code is an enum member, so comparing
        # it with a bare integer never matches; compare by member name instead.
        status_name = span.status.status_code.name
        if status_name != "OK":  # UNSET or ERROR
            status_color = COLORS["red"] if status_name == "ERROR" else COLORS["yellow"]
            span_context += (
                f" {status_color}[{span.status.status_code}]{COLORS['reset']}"
            )

        # Add duration (timestamps are in nanoseconds)
        duration_ms = (span.end_time - span.start_time) / 1e6
        span_context += f" {COLORS['dim']}({duration_ms:.2f}ms){COLORS['reset']}"

        # Print the main span line
        print(span_context)

        # Print attributes indented
        if span.attributes:
            for key, value in span.attributes.items():
                print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")

        # Print events indented
        for event in span.events:
            event_time = self._clock(event.timestamp)
            print(
                f" {COLORS['dim']}{event_time}{COLORS['reset']} "
                f"{COLORS['cyan']}[EVENT]{COLORS['reset']} {event.name}"
            )
            if event.attributes:
                for key, value in event.attributes.items():
                    print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")

    def shutdown(self) -> None:
        """Shutdown the processor; nothing needs to be released."""
        pass

    def force_flush(self, timeout_millis: float = None) -> bool:
        """Force flush any pending spans.

        Output is printed immediately, so there is never anything pending.
        """
        return True

View file

@ -0,0 +1,242 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import sqlite3
import threading
from datetime import datetime, timedelta
from typing import Dict
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
class SQLiteSpanProcessor(SpanProcessor):
    """Span processor that writes every finished span to a SQLite database.

    Each thread that uses the processor gets its own connection. Traces older
    than ``ttl_days`` are purged periodically by a background daemon thread.
    """

    # Seconds between two runs of the background TTL cleanup.
    _CLEANUP_INTERVAL_SECONDS = 3600

    def __init__(self, conn_string, ttl_days=30):
        """Initialize the SQLite span processor with a connection string.

        Args:
            conn_string: Path of the SQLite database file.
            ttl_days: Age in days after which traces and their spans are deleted.
        """
        self.conn_string = conn_string
        self.ttl_days = ttl_days
        self.cleanup_task = None
        self._connections: Dict[int, sqlite3.Connection] = {}
        self._lock = threading.Lock()
        # Set by shutdown() so the cleanup thread stops waiting and exits.
        self._stop_event = threading.Event()
        self.setup_database()

    def _get_connection(self) -> sqlite3.Connection:
        """Get a thread-specific database connection."""
        thread_id = threading.get_ident()
        with self._lock:
            if thread_id not in self._connections:
                # check_same_thread=False because shutdown() closes every
                # connection from the calling thread; with the default check
                # that close would raise sqlite3.ProgrammingError.
                self._connections[thread_id] = sqlite3.connect(
                    self.conn_string, check_same_thread=False
                )
            return self._connections[thread_id]

    @staticmethod
    def _rollback_quietly(conn) -> None:
        """Best-effort rollback so a failed write is not committed later."""
        if conn is None:
            return
        try:
            conn.rollback()
        except sqlite3.Error as e:
            print(f"Error rolling back SQLite transaction: {e}")

    def setup_database(self):
        """Create the necessary tables if they don't exist and start cleanup."""
        # Create directory if it doesn't exist. A bare filename has no
        # directory part, and os.makedirs("") would raise.
        db_dir = os.path.dirname(self.conn_string)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS traces (
                    trace_id TEXT PRIMARY KEY,
                    service_name TEXT,
                    root_span_id TEXT,
                    start_time TIMESTAMP,
                    end_time TIMESTAMP,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS spans (
                    span_id TEXT PRIMARY KEY,
                    trace_id TEXT REFERENCES traces(trace_id),
                    parent_span_id TEXT,
                    name TEXT,
                    start_time TIMESTAMP,
                    end_time TIMESTAMP,
                    attributes TEXT,
                    status TEXT,
                    kind TEXT
                )
                """
            )
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS span_events (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    span_id TEXT REFERENCES spans(span_id),
                    name TEXT,
                    timestamp TIMESTAMP,
                    attributes TEXT
                )
                """
            )
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_traces_created_at
                ON traces(created_at)
                """
            )
            conn.commit()
        finally:
            cursor.close()

        # Start periodic cleanup in a separate thread (only once).
        if self.cleanup_task is None or not self.cleanup_task.is_alive():
            self.cleanup_task = threading.Thread(
                target=self._periodic_cleanup, daemon=True
            )
            self.cleanup_task.start()

    def _cleanup_old_data(self):
        """Delete records older than TTL."""
        conn = None
        try:
            conn = self._get_connection()
            # created_at is filled by CURRENT_TIMESTAMP (UTC, "YYYY-MM-DD
            # HH:MM:SS"). Let SQLite compute the cutoff so both sides of the
            # comparison share the same format and time zone.
            ttl_modifier = f"-{self.ttl_days} days"
            cursor = conn.cursor()
            try:
                # Delete old span events
                cursor.execute(
                    """
                    DELETE FROM span_events
                    WHERE span_id IN (
                        SELECT span_id FROM spans
                        WHERE trace_id IN (
                            SELECT trace_id FROM traces
                            WHERE created_at < datetime('now', ?)
                        )
                    )
                    """,
                    (ttl_modifier,),
                )
                # Delete old spans
                cursor.execute(
                    """
                    DELETE FROM spans
                    WHERE trace_id IN (
                        SELECT trace_id FROM traces
                        WHERE created_at < datetime('now', ?)
                    )
                    """,
                    (ttl_modifier,),
                )
                # Delete old traces
                cursor.execute(
                    "DELETE FROM traces WHERE created_at < datetime('now', ?)",
                    (ttl_modifier,),
                )
                conn.commit()
            finally:
                cursor.close()
        except Exception as e:
            self._rollback_quietly(conn)
            print(f"Error during cleanup: {e}")

    def _periodic_cleanup(self):
        """Run cleanup periodically until shutdown() is called."""
        # Event.wait returns False on timeout and True once the event is set.
        while not self._stop_event.wait(self._CLEANUP_INTERVAL_SECONDS):
            self._cleanup_old_data()

    def on_start(self, span: Span, parent_context=None):
        """Called when a span starts; nothing is stored until the span ends."""
        pass

    def on_end(self, span: Span):
        """Called when a span ends. Export the span data to SQLite.

        Errors are reported on stdout and not raised, so a storage failure
        does not propagate to the caller.
        """
        conn = None
        try:
            conn = self._get_connection()
            cursor = conn.cursor()
            try:
                trace_id = format(span.get_span_context().trace_id, "032x")
                span_id = format(span.get_span_context().span_id, "016x")
                service_name = span.resource.attributes.get("service.name", "unknown")

                parent_span_id = None
                parent_context = span.parent
                if parent_context:
                    parent_span_id = format(parent_context.span_id, "016x")

                # Span timestamps are in nanoseconds.
                start_iso = datetime.fromtimestamp(span.start_time / 1e9).isoformat()
                end_iso = datetime.fromtimestamp(span.end_time / 1e9).isoformat()

                # Insert into traces; only a span without a parent supplies
                # the root span id.
                cursor.execute(
                    """
                    INSERT INTO traces (
                        trace_id, service_name, root_span_id, start_time, end_time
                    ) VALUES (?, ?, ?, ?, ?)
                    ON CONFLICT(trace_id) DO UPDATE SET
                        root_span_id = COALESCE(root_span_id, excluded.root_span_id),
                        start_time = MIN(excluded.start_time, start_time),
                        end_time = MAX(excluded.end_time, end_time)
                    """,
                    (
                        trace_id,
                        service_name,
                        (span_id if not parent_span_id else None),
                        start_iso,
                        end_iso,
                    ),
                )

                # Insert into spans
                cursor.execute(
                    """
                    INSERT INTO spans (
                        span_id, trace_id, parent_span_id, name,
                        start_time, end_time, attributes, status,
                        kind
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        span_id,
                        trace_id,
                        parent_span_id,
                        span.name,
                        start_iso,
                        end_iso,
                        # Attributes may be missing; store an empty object then.
                        json.dumps(dict(span.attributes or {})),
                        span.status.status_code.name,
                        span.kind.name,
                    ),
                )

                for event in span.events:
                    cursor.execute(
                        """
                        INSERT INTO span_events (
                            span_id, name, timestamp, attributes
                        ) VALUES (?, ?, ?, ?)
                        """,
                        (
                            span_id,
                            event.name,
                            datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
                            json.dumps(dict(event.attributes or {})),
                        ),
                    )

                conn.commit()
            finally:
                cursor.close()
        except Exception as e:
            # Discard the partial write so it is not committed with the next span.
            self._rollback_quietly(conn)
            print(f"Error exporting span to SQLite: {e}")

    def shutdown(self):
        """Stop the cleanup thread and close all connections."""
        self._stop_event.set()
        with self._lock:
            for conn in self._connections.values():
                if conn:
                    conn.close()
            self._connections.clear()

    def force_flush(self, timeout_millis=30000):
        """Force export of spans.

        Spans are committed synchronously in ``on_end``, so nothing is ever
        pending; return True to report success to the caller.
        """
        return True

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import threading
from typing import List
from typing import List, Optional
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -17,17 +17,18 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.remote.telemetry.opentelemetry.postgres_processor import (
PostgresSpanProcessor,
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
from llama_stack.providers.utils.telemetry.postgres import PostgresTraceStore
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor,
)
from llama_stack.providers.utils.telemetry.sqlite import SQLiteTraceStore
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
"active_spans": {},
@ -53,19 +54,9 @@ def is_tracing_enabled(tracer):
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
class TelemetryAdapter(Telemetry):
def __init__(self, config: TelemetryConfig) -> None:
self.config = config
self.datasetio = deps[Api.datasetio]
if config.trace_store == "jaeger":
self.trace_store = JaegerTraceStore(
config.jaeger_query_endpoint, config.service_name
)
elif config.trace_store == "postgres":
self.trace_store = PostgresTraceStore(config.postgres_conn_string)
else:
raise ValueError(f"Invalid trace store: {config.trace_store}")
resource = Resource.create(
{
@ -75,25 +66,29 @@ class OpenTelemetryAdapter(Telemetry):
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
trace.get_tracer_provider().add_span_processor(
PostgresSpanProcessor(self.config.postgres_conn_string)
)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
if TelemetrySink.JAEGER in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock
async def initialize(self) -> None:
@ -104,15 +99,17 @@ class OpenTelemetryAdapter(Telemetry):
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event) -> None:
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event)
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event)
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
@ -125,6 +122,7 @@ class OpenTelemetryAdapter(Telemetry):
attributes={
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
},
timestamp=timestamp_ns,
@ -175,11 +173,14 @@ class OpenTelemetryAdapter(Telemetry):
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent) -> None:
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
@ -216,66 +217,33 @@ class OpenTelemetryAdapter(Telemetry):
span.set_status(status)
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")
async def get_trace(self, trace_id: str) -> TraceTree:
return await self.trace_store.get_trace(trace_id)
async def get_agent_trace(
async def query_traces(
self,
session_ids: List[str],
) -> List[EvalTrace]:
traces = []
for session_id in session_ids:
traces_for_session = await self.trace_store.get_traces_for_sessions(
[session_id]
)
for session_trace in traces_for_session:
trace_details = await self._get_simplified_agent_trace(
session_trace.trace_id, session_id
)
traces.extend(trace_details)
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
return await self.trace_store.query_traces(
attribute_conditions=attribute_conditions,
attribute_keys_to_return=attribute_keys_to_return,
limit=limit,
offset=offset,
order_by=order_by,
)
return traces
async def export_agent_trace(self, session_ids: List[str], dataset_id: str) -> None:
traces = await self.get_agent_trace(session_ids)
traces_dict = [
{
"step": trace.step,
"input": trace.input,
"output": trace.output,
"session_id": trace.session_id,
}
for trace in traces
]
await self.datasetio.upload_rows(dataset_id, traces_dict)
async def _get_simplified_agent_trace(
self, trace_id: str, session_id: str
) -> List[EvalTrace]:
trace_tree = await self.get_trace(trace_id)
if not trace_tree or not trace_tree.root:
return []
def find_execute_turn_children(node: SpanNode) -> List[EvalTrace]:
results = []
if node.span.name == "create_and_execute_turn":
# Sort children by start time
sorted_children = sorted(node.children, key=lambda x: x.span.start_time)
for child in sorted_children:
results.append(
EvalTrace(
step=child.span.name,
input=str(child.span.attributes.get("input", "")),
output=str(child.span.attributes.get("output", "")),
session_id=session_id,
expected_output="",
)
)
# Recursively process children
for child in node.children:
results.extend(find_execute_turn_children(child))
return results
return find_execute_turn_children(trace_tree.root)
async def get_materialized_span(
self,
span_id: str,
attribute_keys_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> MaterializedSpan:
return await self.trace_store.get_materialized_span(
span_id=span_id,
attribute_keys_to_return=attribute_keys_to_return,
max_depth=max_depth,
)

View file

@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]:
api=Api.telemetry,
provider_type="inline::meta-reference",
pip_packages=[],
module="llama_stack.providers.inline.meta_reference.telemetry",
config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig",
module="llama_stack.providers.inline.telemetry.meta_reference",
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
),
remote_provider_spec(
api=Api.telemetry,
@ -27,23 +27,4 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
),
),
RemoteProviderSpec(
api=Api.telemetry,
provider_type="remote::opentelemetry-jaeger",
config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
adapter=AdapterSpec(
adapter_type="opentelemetry-jaeger",
pip_packages=[
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-jaeger",
"opentelemetry-semantic-conventions",
],
module="llama_stack.providers.remote.telemetry.opentelemetry",
config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
),
api_dependencies=[
Api.datasetio,
],
),
]

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OpenTelemetryConfig
async def get_adapter_impl(config: OpenTelemetryConfig, deps):
    """Build and initialize the OpenTelemetry adapter.

    Args:
        config: Adapter settings.
        deps: Dependencies passed through to ``OpenTelemetryAdapter``.

    Returns:
        The ``OpenTelemetryAdapter`` after ``initialize()`` was awaited.
    """
    # Imported inside the function, so the module is only loaded when called.
    from .opentelemetry import OpenTelemetryAdapter

    adapter = OpenTelemetryAdapter(config, deps)
    await adapter.initialize()
    return adapter

View file

@ -1,39 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel, Field
class OpenTelemetryConfig(BaseModel):
    """Settings for the OpenTelemetry telemetry adapter."""

    otel_endpoint: str = Field(
        description="The OpenTelemetry collector endpoint URL",
        default="http://localhost:4318/v1/traces",
    )
    service_name: str = Field(
        description="The service name to use for telemetry",
        default="llama-stack",
    )
    trace_store: str = Field(
        description="The trace store to use for telemetry",
        default="postgres",
    )
    jaeger_query_endpoint: str = Field(
        description="The Jaeger query endpoint URL",
        default="http://localhost:16686/api/traces",
    )
    postgres_conn_string: str = Field(
        description="The PostgreSQL connection string to use for storing traces",
        default="host=localhost dbname=llama_stack user=llama_stack password=llama_stack port=5432",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        """Return a sample run configuration using environment placeholders.

        Args:
            **kwargs: Accepted but not used.

        Returns:
            A dict with placeholder strings for ``otel_endpoint`` and
            ``service_name``.
        """
        sample: Dict[str, Any] = {}
        sample["otel_endpoint"] = "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}"
        sample["service_name"] = "${env.OTEL_SERVICE_NAME:llama-stack}"
        return sample

View file

@ -1,92 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
import psycopg2
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
class PostgresSpanProcessor(SpanProcessor):
    """Span processor that writes every finished span to a PostgreSQL table."""

    def __init__(self, conn_string):
        """Initialize the PostgreSQL span processor with a connection string."""
        self.conn_string = conn_string
        # Never assigned elsewhere in this class; kept so shutdown() stays valid.
        self.conn = None
        self.setup_database()

    def setup_database(self):
        """Create the necessary table if it doesn't exist."""
        conn = psycopg2.connect(self.conn_string)
        try:
            # ``with conn`` only wraps a transaction (commit on success,
            # rollback on error); it does not close the connection.
            with conn:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        CREATE TABLE IF NOT EXISTS traces (
                            trace_id TEXT,
                            span_id TEXT,
                            parent_span_id TEXT,
                            name TEXT,
                            start_time TIMESTAMP,
                            end_time TIMESTAMP,
                            attributes JSONB,
                            status TEXT,
                            kind TEXT,
                            service_name TEXT,
                            session_id TEXT
                        )
                        """
                    )
        finally:
            conn.close()

    def on_start(self, span: Span, parent_context=None):
        """Called when a span starts; nothing is stored until the span ends."""
        pass

    def on_end(self, span: Span):
        """Called when a span ends. Export the span data to PostgreSQL.

        Errors are reported on stdout and not raised, so a storage failure
        does not propagate to the caller.
        """
        try:
            # Attributes may be missing; treat that as an empty mapping.
            attributes = span.attributes or {}
            row = (
                format(span.get_span_context().trace_id, "032x"),
                format(span.get_span_context().span_id, "016x"),
                (format(span.parent.span_id, "016x") if span.parent else None),
                span.name,
                # Span timestamps are in nanoseconds.
                datetime.fromtimestamp(span.start_time / 1e9),
                datetime.fromtimestamp(span.end_time / 1e9),
                json.dumps(dict(attributes)),
                span.status.status_code.name,
                span.kind.name,
                span.resource.attributes.get("service.name", "unknown"),
                attributes.get("session_id", None),
            )
            conn = psycopg2.connect(self.conn_string)
            try:
                with conn:  # commits on success, rolls back on error
                    with conn.cursor() as cur:
                        cur.execute(
                            """
                            INSERT INTO traces (
                                trace_id, span_id, parent_span_id, name,
                                start_time, end_time, attributes, status,
                                kind, service_name, session_id
                            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                            """,
                            row,
                        )
            finally:
                # Close explicitly; otherwise one connection leaks per span.
                conn.close()
        except Exception as e:
            print(f"Error exporting span to PostgreSQL: {e}")

    def shutdown(self):
        """Cleanup any resources."""
        if self.conn:
            self.conn.close()

    def force_flush(self, timeout_millis=30000):
        """Force export of spans.

        Spans are written synchronously in ``on_end``, so nothing is ever
        pending; return True to report success to the caller.
        """
        return True

View file

@ -1,141 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timedelta
from typing import List
import aiohttp
from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree
class JaegerTraceStore(TraceStore):
def __init__(self, endpoint: str, service_name: str):
self.endpoint = endpoint
self.service_name = service_name
async def get_trace(self, trace_id: str) -> TraceTree:
params = {
"traceID": trace_id,
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.endpoint}/{trace_id}", params=params
) as response:
if response.status != 200:
raise Exception(
f"Failed to query Jaeger: {response.status} {await response.text()}"
)
trace_data = await response.json()
if not trace_data.get("data") or not trace_data["data"]:
return None
# First pass: Build span map
span_map = {}
for jaeger_span in trace_data["data"][0]["spans"]:
start_time = datetime.fromtimestamp(
jaeger_span["startTime"] / 1000000
)
# Some systems store end time directly in the span
if "endTime" in jaeger_span:
end_time = datetime.fromtimestamp(
jaeger_span["endTime"] / 1000000
)
else:
duration_microseconds = jaeger_span.get("duration", 0)
duration_timedelta = timedelta(
microseconds=duration_microseconds
)
end_time = start_time + duration_timedelta
span = Span(
span_id=jaeger_span["spanID"],
trace_id=trace_id,
name=jaeger_span["operationName"],
start_time=start_time,
end_time=end_time,
parent_span_id=next(
(
ref["spanID"]
for ref in jaeger_span.get("references", [])
if ref["refType"] == "CHILD_OF"
),
None,
),
attributes={
tag["key"]: tag["value"]
for tag in jaeger_span.get("tags", [])
},
)
span_map[span.span_id] = SpanNode(span=span)
# Second pass: Build parent-child relationships
root_node = None
for span_node in span_map.values():
parent_id = span_node.span.parent_span_id
if parent_id and parent_id in span_map:
span_map[parent_id].children.append(span_node)
elif not parent_id:
root_node = span_node
trace = Trace(
trace_id=trace_id,
root_span_id=root_node.span.span_id if root_node else "",
start_time=(
root_node.span.start_time if root_node else datetime.now()
),
end_time=root_node.span.end_time if root_node else None,
)
return TraceTree(trace=trace, root=root_node)
except Exception as e:
raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]:
    """List the traces tagged with any of the given session ids.

    Jaeger is queried once per session id, filtering on the ``session_id``
    tag. Only the trace id is taken from each result; ``root_span_id`` is left
    empty and ``start_time`` is set to the current time as a placeholder.

    Args:
        session_ids: Session identifiers to look up. An empty list returns an
            empty result without contacting Jaeger.

    Returns:
        One ``Trace`` per distinct trace id, in the order first encountered.

    Raises:
        Exception: On a non-200 response or any failure while querying; the
            underlying error is chained as the cause.
    """
    # Local import: this file's import block is outside the reviewed span.
    import json

    traces = []
    if not session_ids:
        return traces

    # Shared across all sessions so a trace that matches more than one
    # session id is reported only once.
    seen_trace_ids = set()
    try:
        # One HTTP session is reused for every request instead of opening a
        # new connection pool per session id.
        async with aiohttp.ClientSession() as session:
            for session_id in session_ids:
                params = {
                    "service": self.service_name,
                    # json.dumps escapes quotes/backslashes in the id; the
                    # compact separators keep the output identical to the
                    # previous hand-built string for ordinary ids.
                    "tags": json.dumps(
                        {"session_id": session_id},
                        separators=(",", ":"),
                        ensure_ascii=False,
                    ),
                    "limit": 100,
                    "lookback": "10000h",
                }
                async with session.get(self.endpoint, params=params) as response:
                    if response.status != 200:
                        raise Exception(
                            f"Failed to query Jaeger: {response.status} {await response.text()}"
                        )
                    traces_data = await response.json()

                for trace_data in traces_data.get("data", []):
                    trace_id = trace_data.get("traceID")
                    if trace_id and trace_id not in seen_trace_ids:
                        seen_trace_ids.add(trace_id)
                        traces.append(
                            Trace(
                                trace_id=trace_id,
                                root_span_id="",
                                start_time=datetime.now(),
                            )
                        )
    except Exception as e:
        raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
    return traces

View file

@ -1,114 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import List, Optional
import psycopg2
from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree
class PostgresTraceStore(TraceStore):
    """Trace store that reads spans from a PostgreSQL ``traces`` table.

    Each query opens its own connection and closes it again before returning.

    NOTE(review): the methods are ``async`` but use the blocking psycopg2
    driver, so the event loop is blocked for the duration of each query.
    """

    def __init__(self, conn_string: str):
        """Remember the psycopg2 connection string used for every query."""
        self.conn_string = conn_string

    @staticmethod
    def _decode_attributes(raw):
        """Return span attributes as a Python object.

        Values the driver already decoded are used as-is, JSON text is
        parsed, and a NULL column becomes an empty dict.
        """
        if raw is None:
            return {}
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    async def get_trace(self, trace_id: str) -> Optional[TraceTree]:
        """Load every span of a trace and assemble them into a tree.

        Args:
            trace_id: Identifier of the trace to load.

        Returns:
            The ``TraceTree`` for the trace, or ``None`` when no span with
            that trace id exists.

        Raises:
            Exception: On any database or decoding failure; the underlying
                error is chained as the cause.
        """
        try:
            conn = psycopg2.connect(self.conn_string)
            try:
                # `with conn` only ends the transaction; the connection itself
                # is closed in the `finally` below.
                with conn, conn.cursor() as cur:
                    # Fetch all spans for the trace
                    cur.execute(
                        """
                        SELECT trace_id, span_id, parent_span_id, name,
                        start_time, end_time, attributes
                        FROM traces
                        WHERE trace_id = %s
                        """,
                        (trace_id,),
                    )
                    spans_data = cur.fetchall()
            finally:
                conn.close()

            if not spans_data:
                return None

            # First pass: build the span map
            span_map = {}
            for span_data in spans_data:
                span = Span(
                    span_id=span_data[1],
                    trace_id=span_data[0],
                    name=span_data[3],
                    start_time=span_data[4],
                    end_time=span_data[5],
                    parent_span_id=span_data[2],
                    attributes=self._decode_attributes(span_data[6]),
                )
                span_map[span.span_id] = SpanNode(span=span)

            # Second pass: build parent-child relationships. A span without a
            # parent becomes the root (the last such span wins); a span whose
            # parent is not in this trace is left unattached.
            root_node = None
            for span_node in span_map.values():
                parent_id = span_node.span.parent_span_id
                if not parent_id:
                    root_node = span_node
                elif parent_id in span_map:
                    span_map[parent_id].children.append(span_node)

            trace = Trace(
                trace_id=trace_id,
                root_span_id=root_node.span.span_id if root_node else "",
                start_time=(
                    root_node.span.start_time if root_node else datetime.now()
                ),
                end_time=root_node.span.end_time if root_node else None,
            )
            return TraceTree(trace=trace, root=root_node)
        except Exception as e:
            raise Exception(
                f"Error querying PostgreSQL trace structure: {str(e)}"
            ) from e

    async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]:
        """List the traces having a span tagged with any given session id.

        Args:
            session_ids: Session identifiers matched against the
                ``session_id`` attribute. An empty list returns an empty
                result without opening a connection.

        Returns:
            One ``Trace`` per matching trace id. ``start_time`` is the
            earliest span start in that trace; ``root_span_id`` is left empty.

        Raises:
            Exception: On any database failure; the underlying error is
                chained as the cause.
        """
        if not session_ids:
            return []
        try:
            conn = psycopg2.connect(self.conn_string)
            try:
                with conn, conn.cursor() as cur:
                    # Query traces for all session IDs
                    cur.execute(
                        """
                        SELECT DISTINCT trace_id, MIN(start_time) as start_time
                        FROM traces
                        WHERE attributes->>'session_id' = ANY(%s)
                        GROUP BY trace_id
                        """,
                        (session_ids,),
                    )
                    traces_data = cur.fetchall()
            finally:
                conn.close()

            return [
                Trace(
                    trace_id=trace_data[0],
                    root_span_id="",
                    start_time=trace_data[1],
                )
                for trace_data in traces_data
            ]
        except Exception as e:
            raise Exception(f"Error querying PostgreSQL traces: {str(e)}") from e

View file

@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import List, Optional
import aiosqlite
from llama_stack.apis.telemetry import (
MaterializedSpan,
QueryCondition,
Trace,
TraceStore,
)
class SQLiteTraceStore(TraceStore):
    """Trace store that reads from the ``traces`` and ``spans`` SQLite tables.

    Caller-supplied values (attribute keys, condition values, limit, offset)
    are always passed as bound parameters. The two pieces that must appear in
    the SQL text itself, comparison operators and ORDER BY columns, are
    validated before being spliced in.
    """

    # Binary comparison operators accepted in QueryCondition.op. Anything
    # else is rejected so that `op` cannot carry arbitrary SQL.
    _ALLOWED_OPS = frozenset(
        {
            "=",
            "==",
            "!=",
            "<>",
            "<",
            "<=",
            ">",
            ">=",
            "LIKE",
            "NOT LIKE",
            "GLOB",
            "NOT GLOB",
            "IS",
            "IS NOT",
        }
    )

    def __init__(self, conn_string: str):
        """Remember the database path handed to ``aiosqlite.connect``."""
        self.conn_string = conn_string

    @classmethod
    def _normalize_op(cls, op: str) -> str:
        """Return ``op`` upper-cased with collapsed whitespace, or raise ValueError."""
        normalized = " ".join(str(op).upper().split())
        if normalized not in cls._ALLOWED_OPS:
            raise ValueError(f"Unsupported query operator: {op!r}")
        return normalized

    @staticmethod
    def _order_term(field: str) -> str:
        """Translate one ``order_by`` entry (``col`` or ``-col``) to SQL."""
        descending = field.startswith("-")
        column = field[1:] if descending else field
        # Column names cannot be bound as parameters, so only plain ASCII
        # identifiers are allowed into the statement text.
        if not (column.isascii() and column.isidentifier()):
            raise ValueError(f"Invalid order_by field: {field!r}")
        return f"t.{column} {'DESC' if descending else 'ASC'}"

    @staticmethod
    def _parse_timestamp(value):
        """Parse an ISO-8601 string; a NULL column value stays ``None``."""
        return None if value is None else datetime.fromisoformat(value)

    async def query_traces(
        self,
        attribute_conditions: Optional[List[QueryCondition]] = None,
        attribute_keys_to_return: Optional[List[str]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]:
        """Return traces having a span that satisfies every condition.

        Args:
            attribute_conditions: Conditions on span attributes, combined
                with AND. Each compares ``$.<key>`` of the span's JSON
                attributes against ``value`` using ``op``.
            attribute_keys_to_return: Attribute keys added to the SELECT
                list. They take part in DISTINCT but are not copied onto the
                returned ``Trace`` objects.
            limit: Maximum number of rows; ``None`` means no limit.
            offset: Number of rows to skip; ``None`` means 0.
            order_by: Columns of ``traces`` to sort by; a leading ``-``
                selects descending order.

        Returns:
            The matching traces.

        Raises:
            ValueError: If an operator or ``order_by`` field is not allowed.
        """
        select_parts = [
            "t.trace_id",
            "t.root_span_id",
            "t.start_time",
            "t.end_time",
        ]
        # Parameters are appended in the same order their `?` placeholders
        # appear in the statement: SELECT list, WHERE, then LIMIT/OFFSET.
        params = []

        for index, key in enumerate(attribute_keys_to_return or []):
            # Aliases are positional because the key itself is untrusted and
            # the alias is never read back.
            select_parts.append(f"json_extract(s.attributes, ?) AS attr_{index}")
            params.append(f"$.{key}")

        query = (
            f"SELECT DISTINCT {', '.join(select_parts)}"
            " FROM traces t"
            " JOIN spans s ON t.trace_id = s.trace_id"
        )

        where_clauses = []
        for condition in attribute_conditions or []:
            op = self._normalize_op(condition.op)
            where_clauses.append(f"json_extract(s.attributes, ?) {op} ?")
            params.extend((f"$.{condition.key}", condition.value))
        if where_clauses:
            query += " WHERE " + " AND ".join(where_clauses)

        if order_by:
            query += " ORDER BY " + ", ".join(
                self._order_term(field) for field in order_by
            )

        # SQLite treats a negative LIMIT as "no upper bound".
        query += " LIMIT ? OFFSET ?"
        params.extend(
            (-1 if limit is None else limit, 0 if offset is None else offset)
        )

        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()

        return [
            Trace(
                trace_id=row["trace_id"],
                root_span_id=row["root_span_id"],
                start_time=datetime.fromisoformat(row["start_time"]),
                end_time=self._parse_timestamp(row["end_time"]),
            )
            for row in rows
        ]

    async def get_materialized_span(
        self,
        span_id: str,
        attribute_keys_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
    ) -> MaterializedSpan:
        """Load a span and its descendants as a nested tree.

        Args:
            span_id: Identifier of the span that becomes the tree root.
            attribute_keys_to_return: When given, only these attribute keys
                are kept on every returned span; otherwise all attributes
                are returned.
            max_depth: Maximum number of levels including the root; ``None``
                walks the whole subtree.

        Returns:
            The requested span with ``children`` populated recursively.

        Raises:
            ValueError: If no span with ``span_id`` exists.
        """
        # Build the attributes selection. Keys and JSON paths are bound as
        # parameters rather than written into the statement.
        attributes_select = "s.attributes"
        attribute_params = []
        if attribute_keys_to_return:
            pairs = ", ".join(
                "?, json_extract(s.attributes, ?)" for _ in attribute_keys_to_return
            )
            attributes_select = f"json_object({pairs})"
            for key in attribute_keys_to_return:
                attribute_params.extend((key, f"$.{key}"))

        # Recursive CTE: the anchor row is the requested span (depth 1), and
        # each step adds the children of the rows found so far.
        query = f"""
            WITH RECURSIVE span_tree AS (
                SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
                FROM spans s
                WHERE s.span_id = ?
                UNION ALL
                SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
                FROM spans s
                JOIN span_tree st ON s.parent_span_id = st.span_id
                WHERE (? IS NULL OR st.depth < ?)
            )
            SELECT *
            FROM span_tree
            ORDER BY depth, start_time
        """
        # The attribute selection appears twice in the statement, so its
        # parameters are supplied twice, in placeholder order.
        params = [
            *attribute_params,
            span_id,
            *attribute_params,
            max_depth,
            max_depth,
        ]

        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()

        if not rows:
            raise ValueError(f"Span {span_id} not found")

        # Rows are ordered by depth, so a parent is always registered before
        # any of its children is attached.
        spans_by_id = {}
        root_span = None
        for row in rows:
            raw_attributes = row["filtered_attributes"]
            raw_status = row["status"]
            span = MaterializedSpan(
                span_id=row["span_id"],
                trace_id=row["trace_id"],
                parent_span_id=row["parent_span_id"],
                name=row["name"],
                start_time=datetime.fromisoformat(row["start_time"]),
                # NOTE(review): assumes Span.end_time accepts None for spans
                # that have not finished yet; confirm against the Span model.
                end_time=self._parse_timestamp(row["end_time"]),
                attributes=(
                    {} if raw_attributes is None else json.loads(raw_attributes)
                ),
                # status is Optional on MaterializedSpan, so NULL stays None.
                status=None if raw_status is None else raw_status.lower(),
                children=[],
            )
            spans_by_id[span.span_id] = span
            if span.span_id == span_id:
                root_span = span
            elif span.parent_span_id in spans_by_id:
                spans_by_id[span.parent_span_id].children.append(span)

        return root_span