Telemetry API redesign (#525)

# What does this PR do?
Change the Telemetry API to be able to support different use cases like
returning traces for the UI and ability to export for Evals.
Other changes:
* Add a new trace_protocol decorator to decorate all our API methods so
that any call to them will automatically get traced across all impls.
* There is some issue with the decorator pattern of span creation when
using async generators, where there are multiple yields with in the same
context. I think its much more explicit by using the explicit context
manager pattern using with. I moved the span creations in agent instance
to be using with
* Inject session id at the turn level, which should quickly give us all
traces across turns for a given session

Addresses #509

## Test Plan
```
llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml
PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000


 curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \
-H 'Content-Type: application/json' \
-d '{
  "attribute_filters": [
    {
      "key": "session_id",
      "op": "eq",
      "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65" }],
  "limit": 100,
  "offset": 0,
  "order_by": ["start_time"]
}' | jq .
[
  {
    "trace_id": "6902f54b83b4b48be18a6f422b13e16f",
    "root_span_id": "5f37b85543afc15a",
    "start_time": "2024-12-04T08:08:30.501587",
    "end_time": "2024-12-04T08:08:36.026463"
  },
  {
    "trace_id": "92227dac84c0615ed741be393813fb5f",
    "root_span_id": "af7c5bb46665c2c8",
    "start_time": "2024-12-04T08:08:36.031170",
    "end_time": "2024-12-04T08:08:41.693301"
  },
  {
    "trace_id": "7d578a6edac62f204ab479fba82f77b6",
    "root_span_id": "1d935e3362676896",
    "start_time": "2024-12-04T08:08:41.695204",
    "end_time": "2024-12-04T08:08:47.228016"
  },
  {
    "trace_id": "dbd767d76991bc816f9f078907dc9ff2",
    "root_span_id": "f5a7ee76683b9602",
    "start_time": "2024-12-04T08:08:47.234578",
    "end_time": "2024-12-04T08:08:53.189412"
  }
]


curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \
-H 'Content-Type: application/json' \
-d '{ "span_id" : "6cceb4b48a156913", "max_depth": 2, "attributes_to_return": ["input"] }' | jq .
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   875  100   790  100    85  18462   1986 --:--:-- --:--:-- --:--:-- 20833
{
  "span_id": "6cceb4b48a156913",
  "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
  "parent_span_id": "892a66d726c7f990",
  "name": "retrieve_rag_context",
  "start_time": "2024-12-04T09:28:21.781995",
  "end_time": "2024-12-04T09:28:21.913352",
  "attributes": {
    "input": [
      "{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}",
      "{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}"
    ]
  },
  "children": [
    {
      "span_id": "1a2df181854064a8",
      "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
      "parent_span_id": "6cceb4b48a156913",
      "name": "MemoryRouter.query_documents",
      "start_time": "2024-12-04T09:28:21.787620",
      "end_time": "2024-12-04T09:28:21.906512",
      "attributes": {
        "input": null
      },
      "children": [],
      "status": "ok"
    }
  ],
  "status": "ok"
}

```

<img width="1677" alt="Screenshot 2024-12-04 at 9 42 56 AM"
src="https://github.com/user-attachments/assets/4d3cea93-05ce-415a-93d9-4b1628631bf8">
This commit is contained in:
Dinesh Yeduguru 2024-12-04 11:22:45 -08:00 committed by GitHub
parent 16769256b7
commit fcd6449519
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 1551 additions and 245 deletions

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import TelemetryConfig, TelemetrySink
from .telemetry import TelemetryAdapter
__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
impl = TelemetryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
JAEGER = "jaeger"
SQLITE = "sqlite"
CONSOLE = "console"
class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: jaeger, sqlite, console)",
)
sqlite_db_path: str = Field(
default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
description="The path to the SQLite database to use for storing traces",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:['console', 'sqlite']}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}",
}

View file

@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
# Colors for console output
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
class ConsoleSpanProcessor(SpanProcessor):
"""A SpanProcessor that prints spans to the console with color formatting."""
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
"""Called when a span starts."""
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['cyan']}{span.name}{COLORS['reset']}"
)
def on_end(self, span: ReadableSpan) -> None:
"""Called when a span ends."""
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
# Build the span context string
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['cyan']}{span.name}{COLORS['reset']} "
)
# Add status if not OK
if span.status.status_code != 0: # UNSET or ERROR
status_color = (
COLORS["red"] if span.status.status_code == 2 else COLORS["yellow"]
)
span_context += (
f" {status_color}[{span.status.status_code}]{COLORS['reset']}"
)
# Add duration
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f" {COLORS['dim']}({duration_ms:.2f}ms){COLORS['reset']}"
# Print the main span line
print(span_context)
# Print attributes indented
if span.attributes:
for key, value in span.attributes.items():
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
# Print events indented
for event in span.events:
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
print(
f" {COLORS['dim']}{event_time}{COLORS['reset']} "
f"{COLORS['cyan']}[EVENT]{COLORS['reset']} {event.name}"
)
if event.attributes:
for key, value in event.attributes.items():
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
def shutdown(self) -> None:
"""Shutdown the processor."""
pass
def force_flush(self, timeout_millis: float = None) -> bool:
"""Force flush any pending spans."""
return True

View file

@ -0,0 +1,242 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import sqlite3
import threading
from datetime import datetime, timedelta
from typing import Dict
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
class SQLiteSpanProcessor(SpanProcessor):
def __init__(self, conn_string, ttl_days=30):
"""Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string
self.ttl_days = ttl_days
self.cleanup_task = None
self._thread_local = threading.local()
self._connections: Dict[int, sqlite3.Connection] = {}
self._lock = threading.Lock()
self.setup_database()
def _get_connection(self) -> sqlite3.Connection:
"""Get a thread-specific database connection."""
thread_id = threading.get_ident()
with self._lock:
if thread_id not in self._connections:
conn = sqlite3.connect(self.conn_string)
self._connections[thread_id] = conn
return self._connections[thread_id]
def setup_database(self):
"""Create the necessary tables if they don't exist."""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS traces (
trace_id TEXT PRIMARY KEY,
service_name TEXT,
root_span_id TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS spans (
span_id TEXT PRIMARY KEY,
trace_id TEXT REFERENCES traces(trace_id),
parent_span_id TEXT,
name TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
attributes TEXT,
status TEXT,
kind TEXT
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS span_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
span_id TEXT REFERENCES spans(span_id),
name TEXT,
timestamp TIMESTAMP,
attributes TEXT
)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_traces_created_at
ON traces(created_at)
"""
)
conn.commit()
cursor.close()
# Start periodic cleanup in a separate thread
self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True)
self.cleanup_task.start()
def _cleanup_old_data(self):
"""Delete records older than TTL."""
try:
conn = self._get_connection()
cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat()
cursor = conn.cursor()
# Delete old span events
cursor.execute(
"""
DELETE FROM span_events
WHERE span_id IN (
SELECT span_id FROM spans
WHERE trace_id IN (
SELECT trace_id FROM traces
WHERE created_at < ?
)
)
""",
(cutoff_date,),
)
# Delete old spans
cursor.execute(
"""
DELETE FROM spans
WHERE trace_id IN (
SELECT trace_id FROM traces
WHERE created_at < ?
)
""",
(cutoff_date,),
)
# Delete old traces
cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,))
conn.commit()
cursor.close()
except Exception as e:
print(f"Error during cleanup: {e}")
def _periodic_cleanup(self):
"""Run cleanup periodically."""
import time
while True:
time.sleep(3600) # Sleep for 1 hour
self._cleanup_old_data()
def on_start(self, span: Span, parent_context=None):
"""Called when a span starts."""
pass
def on_end(self, span: Span):
"""Called when a span ends. Export the span data to SQLite."""
try:
conn = self._get_connection()
cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
# Insert into traces
cursor.execute(
"""
INSERT INTO traces (
trace_id, service_name, root_span_id, start_time, end_time
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(trace_id) DO UPDATE SET
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
start_time = MIN(excluded.start_time, start_time),
end_time = MAX(excluded.end_time, end_time)
""",
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
),
)
# Insert into spans
cursor.execute(
"""
INSERT INTO spans (
span_id, trace_id, parent_span_id, name,
start_time, end_time, attributes, status,
kind
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
span_id,
trace_id,
parent_span_id,
span.name,
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
),
)
for event in span.events:
cursor.execute(
"""
INSERT INTO span_events (
span_id, name, timestamp, attributes
) VALUES (?, ?, ?, ?)
""",
(
span_id,
event.name,
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
json.dumps(dict(event.attributes)),
),
)
conn.commit()
cursor.close()
except Exception as e:
print(f"Error exporting span to SQLite: {e}")
def shutdown(self):
"""Cleanup any resources."""
with self._lock:
for conn in self._connections.values():
if conn:
conn.close()
self._connections.clear()
def force_flush(self, timeout_millis=30000):
"""Force export of spans."""
pass

View file

@ -0,0 +1,247 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import threading
from typing import List, Optional
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor,
)
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.apis.telemetry import * # noqa: F403
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
"active_spans": {},
"counters": {},
"gauges": {},
"up_down_counters": {},
}
_global_lock = threading.Lock()
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class TelemetryAdapter(Telemetry):
def __init__(self, config: TelemetryConfig) -> None:
self.config = config
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
)
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
if TelemetrySink.JAEGER in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.type,
attributes={
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
},
timestamp=timestamp_ns,
)
else:
print(
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
)
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
)
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(
event.metric, event.unit
)
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(
self, name: str, unit: str
) -> metrics.UpDownCounter:
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = (
self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
parent_span = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
)
span.set_status(status)
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
return await self.trace_store.query_traces(
attribute_filters=attribute_filters,
limit=limit,
offset=offset,
order_by=order_by,
)
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
return await self.trace_store.get_materialized_span(
span_id=span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)