diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 733eb79a2..d2243c96f 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -23,7 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
-from llama_stack.distribution.tracing import trace_protocol, traced
+from llama_stack.distribution.tracing import trace_protocol
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.common.deployment_types import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
@@ -428,7 +428,6 @@ class Agents(Protocol):
     ) -> AgentCreateResponse: ...
 
     @webmethod(route="/agents/turn/create")
-    @traced(input="messages")
     async def create_agent_turn(
         self,
         agent_id: str,
diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py
index 2340ab377..22acc3211 100644
--- a/llama_stack/apis/datasetio/datasetio.py
+++ b/llama_stack/apis/datasetio/datasetio.py
@@ -38,7 +38,7 @@ class DatasetIO(Protocol):
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult: ...
 
-    @webmethod(route="/datasetio/upload", method="POST")
-    async def upload_rows(
+    @webmethod(route="/datasetio/append-rows", method="POST")
+    async def append_rows(
         self, dataset_id: str, rows: List[Dict[str, Any]]
     ) -> None: ...
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index e0f4d1e3e..85b29a147 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -21,7 +21,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
-from llama_stack.distribution.tracing import trace_protocol, traced
+from llama_stack.distribution.tracing import trace_protocol
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403
@@ -227,7 +227,6 @@ class Inference(Protocol):
     model_store: ModelStore
 
     @webmethod(route="/inference/completion")
-    @traced(input="content")
     async def completion(
         self,
         model_id: str,
@@ -239,7 +238,6 @@
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
 
     @webmethod(route="/inference/chat-completion")
-    @traced(input="messages")
     async def chat_completion(
         self,
         model_id: str,
@@ -257,7 +255,6 @@
     ]: ...
 
     @webmethod(route="/inference/embeddings")
-    @traced(input="contents")
     async def embeddings(
         self,
         model_id: str,
diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py
index c314cb513..b75df8a1a 100644
--- a/llama_stack/apis/memory/memory.py
+++ b/llama_stack/apis/memory/memory.py
@@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.memory_banks import *  # noqa: F403
-from llama_stack.distribution.tracing import trace_protocol, traced
+from llama_stack.distribution.tracing import trace_protocol
 
 
 @json_schema_type
@@ -50,7 +50,6 @@ class Memory(Protocol):
     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
-    @traced(input="documents")
     @webmethod(route="/memory/insert")
     async def insert_documents(
         self,
@@ -60,7 +59,6 @@
     ) -> None: ...
@webmethod(route="/memory/query") - @traced(input="query") async def query_documents( self, bank_id: str, diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 870e178bc..41058f107 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -10,7 +10,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_stack.distribution.tracing import trace_protocol, traced +from llama_stack.distribution.tracing import trace_protocol from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 @@ -50,7 +50,6 @@ class Safety(Protocol): shield_store: ShieldStore @webmethod(route="/safety/run-shield") - @traced(input="messages") async def run_shield( self, shield_id: str, diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index e799851c9..2ff783c46 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -21,6 +21,9 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +# Add this constant near the top of the file, after the imports +DEFAULT_TTL_DAYS = 7 + @json_schema_type class SpanStatus(Enum): @@ -147,57 +150,39 @@ class EvalTrace(BaseModel): @json_schema_type -class MaterializedSpan(Span): - children: List["MaterializedSpan"] = Field(default_factory=list) +class SpanWithChildren(Span): + children: List["SpanWithChildren"] = Field(default_factory=list) status: Optional[SpanStatus] = None @json_schema_type class QueryCondition(BaseModel): key: str - op: str + op: Literal["eq", "ne", "gt", "lt"] value: Any -class TraceStore(Protocol): - - async def query_traces( - self, - attribute_conditions: Optional[List[QueryCondition]] = None, - attribute_keys_to_return: Optional[List[str]] = None, - limit: Optional[int] = 100, - offset: Optional[int] = 0, - order_by: Optional[List[str]] = None, - ) -> List[Trace]: ... - - async def get_materialized_span( - self, - span_id: str, - attribute_keys_to_return: Optional[List[str]] = None, - max_depth: Optional[int] = None, - ) -> MaterializedSpan: ... - - @runtime_checkable class Telemetry(Protocol): @webmethod(route="/telemetry/log-event") - async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: ... + async def log_event( + self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400 + ) -> None: ... - @webmethod(route="/telemetry/query-traces", method="GET") + @webmethod(route="/telemetry/query-traces", method="POST") async def query_traces( self, - attribute_conditions: Optional[List[QueryCondition]] = None, - attribute_keys_to_return: Optional[List[str]] = None, + attribute_filters: Optional[List[QueryCondition]] = None, limit: Optional[int] = 100, offset: Optional[int] = 0, order_by: Optional[List[str]] = None, ) -> List[Trace]: ... - @webmethod(route="/telemetry/get-materialized-span", method="GET") - async def get_materialized_span( + @webmethod(route="/telemetry/get-span-tree", method="POST") + async def get_span_tree( self, span_id: str, - attribute_keys_to_return: Optional[List[str]] = None, + attributes_to_return: Optional[List[str]] = None, max_depth: Optional[int] = None, - ) -> MaterializedSpan: ... + ) -> SpanWithChildren: ... 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 4e5d83763..5b75a525b 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -222,8 +222,8 @@ class DatasetIORouter(DatasetIO):
             filter_condition=filter_condition,
         )
 
-    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        return await self.routing_table.get_provider_impl(dataset_id).upload_rows(
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
         )
diff --git a/llama_stack/distribution/tracing.py b/llama_stack/distribution/tracing.py
index 0e27b775a..ea663ec89 100644
--- a/llama_stack/distribution/tracing.py
+++ b/llama_stack/distribution/tracing.py
@@ -24,7 +24,7 @@ def serialize_value(value: Any) -> str:
         return value.model_dump_json()
     elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
         return json.dumps([item.model_dump_json() for item in value])
-    elif hasattr(value, "to_dict"):  # For objects with to_dict method
+    elif hasattr(value, "to_dict"):
         return json.dumps(value.to_dict())
     elif isinstance(value, (dict, list, int, float, str, bool)):
         return json.dumps(value)
@@ -34,21 +34,6 @@ def serialize_value(value: Any) -> str:
     return str(value)
 
 
-def traced(input: str = None):
-    """
-    A method decorator that enables tracing with input and output capture.
-
-    Args:
-        input: Name of the input parameter to capture in traces
-    """
-
-    def decorator(method: Callable) -> Callable:
-        method._trace_input = input
-        return method
-
-    return decorator
-
-
 def trace_protocol(cls: Type[T]) -> Type[T]:
     """
     A class decorator that automatically traces all methods in a protocol/base class
@@ -59,22 +44,6 @@
         is_async = asyncio.iscoroutinefunction(method)
         is_async_gen = inspect.isasyncgenfunction(method)
 
-        def get_traced_input(args: tuple, kwargs: dict) -> dict:
-            trace_input = getattr(method, "_trace_input", None)
-            if not trace_input:
-                return {}
-
-            # Get the mapping of parameter names to values
-            sig = inspect.signature(method)
-            bound_args = sig.bind(None, *args, **kwargs)  # None for self
-            bound_args.apply_defaults()
-            params = dict(list(bound_args.arguments.items())[1:])  # Skip 'self'
-
-            # Return the input value if the key exists
-            if trace_input in params:
-                return {"input": serialize_value(params[trace_input])}
-            return {}
-
         def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
             class_name = self.__class__.__name__
             method_name = method.__name__
@@ -87,7 +56,6 @@
                 "method": method_name,
                 "type": span_type,
                 "args": serialize_value(args),
-                **get_traced_input(args, kwargs),
             }
 
             return class_name, method_name, span_attributes
@@ -145,33 +113,16 @@
         else:
             return sync_wrapper
 
-    # Store the original __init_subclass__ if it exists
     original_init_subclass = getattr(cls, "__init_subclass__", None)
 
-    # Define a new __init_subclass__ to handle child classes
     def __init_subclass__(cls_child, **kwargs):  # noqa: N807
-        # Call original __init_subclass__ if it exists
         if original_init_subclass:
            original_init_subclass(**kwargs)
 
-        traced_methods = {}
-        for parent in cls_child.__mro__[1:]:  # Skip the class itself
-            for name, method in vars(parent).items():
-                if inspect.isfunction(method) and getattr(
-                    method, "_trace_input", None
-                ):  # noqa: B009
-                    traced_methods[name] = getattr(method, "_trace_input")  # noqa: B009
-
-        # Trace child class methods if their name matches a traced parent method
         for name, method in vars(cls_child).items():
             if inspect.isfunction(method) and not name.startswith("_"):
-                if name in traced_methods:
-                    # Copy the trace configuration from the parent
-                    setattr(method, "_trace_input", traced_methods[name])  # noqa: B010
-
                 setattr(cls_child, name, trace_method(method))  # noqa: B010
 
-    # Set the new __init_subclass__
     cls.__init_subclass__ = classmethod(__init_subclass__)
 
     return cls
diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index 9d7ea9870..4775ba708 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -132,7 +132,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
             next_page_token=str(end),
         )
 
-    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_info = self.dataset_infos.get(dataset_id)
         if dataset_info is None:
             raise ValueError(f"Dataset with id {dataset_id} not found")
diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py
new file mode 100644
index 000000000..838aaa4e1
--- /dev/null
+++ b/llama_stack/providers/inline/meta_reference/telemetry/console.py
@@ -0,0 +1,135 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from typing import List, Optional
+
+from llama_stack.apis.telemetry import *  # noqa: F403
+
+from .config import ConsoleConfig, LogFormat
+
+
+class ConsoleTelemetryImpl(Telemetry):
+    def __init__(self, config: ConsoleConfig) -> None:
+        self.config = config
+        self.spans = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
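+
+    # log_event caches span-start payloads by span_id so each printed line can
+    # be prefixed with the full dotted span path, rebuilt by walking
+    # parent_span_id links.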
+    async def log_event(self, event: Event) -> None:
+        if (
+            isinstance(event, StructuredLogEvent)
+            and event.payload.type == StructuredLogType.SPAN_START.value
+        ):
+            self.spans[event.span_id] = event.payload
+
+        names = []
+        span_id = event.span_id
+        while True:
+            span_payload = self.spans.get(span_id)
+            if not span_payload:
+                break
+
+            names = [span_payload.name] + names
+            span_id = span_payload.parent_span_id
+
+        span_name = ".".join(names) if names else None
+
+        if self.config.log_format == LogFormat.JSON:
+            formatted = format_event_json(event, span_name)
+        else:
+            formatted = format_event_text(event, span_name)
+
+        if formatted:
+            print(formatted)
+
+    async def query_traces(
+        self,
+        attribute_filters: Optional[List[QueryCondition]] = None,
+        limit: Optional[int] = 100,
+        offset: Optional[int] = 0,
+        order_by: Optional[List[str]] = None,
+    ) -> List[Trace]:
+        raise NotImplementedError("Console telemetry does not support trace querying")
+
+    async def get_span_tree(
+        self,
+        span_id: str,
+        attributes_to_return: Optional[List[str]] = None,
+        max_depth: Optional[int] = None,
+    ) -> SpanWithChildren:
+        raise NotImplementedError("Console telemetry does not support span querying")
+
+
+COLORS = {
+    "reset": "\033[0m",
+    "bold": "\033[1m",
+    "dim": "\033[2m",
+    "red": "\033[31m",
+    "green": "\033[32m",
+    "yellow": "\033[33m",
+    "blue": "\033[34m",
+    "magenta": "\033[35m",
+    "cyan": "\033[36m",
+    "white": "\033[37m",
+}
+
+SEVERITY_COLORS = {
+    LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
+    LogSeverity.DEBUG: COLORS["cyan"],
+    LogSeverity.INFO: COLORS["green"],
+    LogSeverity.WARN: COLORS["yellow"],
+    LogSeverity.ERROR: COLORS["red"],
+    LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
+}
+
+
+def format_event_text(event: Event, span_name: Optional[str]) -> Optional[str]:
+    timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
+    span = ""
+    if span_name:
+        span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
+
+    if isinstance(event, UnstructuredLogEvent):
+        severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
+        return (
+            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
+            f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
+            f"{span}"
+            f"{event.message}"
+        )
+
+    elif isinstance(event, StructuredLogEvent):
+        return None
+
+    return f"Unknown event type: {event}"
+
+
+def format_event_json(event: Event, span_name: Optional[str]) -> Optional[str]:
+    base_data = {
+        "timestamp": event.timestamp.isoformat(),
+        "trace_id": event.trace_id,
+        "span_id": event.span_id,
+        "span_name": span_name,
+    }
+
+    if isinstance(event, UnstructuredLogEvent):
+        base_data.update(
+            {"type": "log", "severity": event.severity.name, "message": event.message}
+        )
+        return json.dumps(base_data)
+
+    elif isinstance(event, StructuredLogEvent):
+        return None
+
+    return json.dumps({"error": f"Unknown event type: {event}"})
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index 1f27876e0..6540a667f 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -24,7 +24,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.console_span_processo
 from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
     SQLiteSpanProcessor,
 )
-from llama_stack.providers.utils.telemetry.sqlite import SQLiteTraceStore
+from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
 
 from llama_stack.apis.telemetry import *  # noqa: F403
@@ -222,28 +222,26 @@ class TelemetryAdapter(Telemetry):
 
     async def query_traces(
         self,
-        attribute_conditions: Optional[List[QueryCondition]] = None,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attribute_filters: Optional[List[QueryCondition]] = None,
         limit: Optional[int] = 100,
         offset: Optional[int] = 0,
         order_by: Optional[List[str]] = None,
     ) -> List[Trace]:
         return await self.trace_store.query_traces(
-            attribute_conditions=attribute_conditions,
-            attribute_keys_to_return=attribute_keys_to_return,
+            attribute_filters=attribute_filters,
             limit=limit,
             offset=offset,
             order_by=order_by,
         )
 
-    async def get_materialized_span(
+    async def get_span_tree(
         self,
         span_id: str,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attributes_to_return: Optional[List[str]] = None,
         max_depth: Optional[int] = None,
-    ) -> MaterializedSpan:
+    ) -> SpanWithChildren:
         return await self.trace_store.get_materialized_span(
             span_id=span_id,
-            attribute_keys_to_return=attribute_keys_to_return,
+            attributes_to_return=attributes_to_return,
             max_depth=max_depth,
         )
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index f43a1991a..fe5910cc8 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -96,7 +96,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
             next_page_token=str(end),
         )
 
-    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
         loaded_dataset = load_hf_dataset(dataset_def)
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
new file mode 100644
index 000000000..04eb71ce0
--- /dev/null
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
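+
+# OpenTelemetry-backed Telemetry implementation: spans and metrics are fanned
+# out to the sinks configured on OpenTelemetryConfig (OTLP/Jaeger, SQLite,
+# and/or console).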
+
+import threading
+from typing import List, Optional
+
+from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.remote.telemetry.opentelemetry.console_span_processor import (
+    ConsoleSpanProcessor,
+)
+from llama_stack.providers.remote.telemetry.opentelemetry.sqlite_span_processor import (
+    SQLiteSpanProcessor,
+)
+from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
+
+from opentelemetry import metrics, trace
+from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.semconv.resource import ResourceAttributes
+
+
+from llama_stack.apis.telemetry import *  # noqa: F403
+
+from .config import OpenTelemetryConfig, TelemetrySink
+
+_GLOBAL_STORAGE = {
+    "active_spans": {},
+    "counters": {},
+    "gauges": {},
+    "up_down_counters": {},
+}
+_global_lock = threading.Lock()
+
+
+def string_to_trace_id(s: str) -> int:
+    # Convert the string to bytes and then to an integer
+    return int.from_bytes(s.encode(), byteorder="big", signed=False)
+
+
+def string_to_span_id(s: str) -> int:
+    # Use only the first 8 bytes (64 bits) for span ID
+    return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
+
+
+def is_tracing_enabled(tracer):
+    with tracer.start_as_current_span("check_tracing") as span:
+        return span.is_recording()
+
+
+class OpenTelemetryAdapter(Telemetry):
+    def __init__(self, config: OpenTelemetryConfig, deps) -> None:
+        self.config = config
+        self.datasetio = deps[Api.datasetio]
+
+        resource = Resource.create(
+            {
+                ResourceAttributes.SERVICE_NAME: self.config.service_name,
+            }
+        )
+
+        provider = TracerProvider(resource=resource)
+        trace.set_tracer_provider(provider)
+        if TelemetrySink.JAEGER in self.config.sinks:
+            otlp_exporter = OTLPSpanExporter(
+                endpoint=self.config.otel_endpoint,
+            )
+            span_processor = BatchSpanProcessor(otlp_exporter)
+            trace.get_tracer_provider().add_span_processor(span_processor)
+            metric_reader = PeriodicExportingMetricReader(
+                OTLPMetricExporter(
+                    endpoint=self.config.otel_endpoint,
+                )
+            )
+            metric_provider = MeterProvider(
+                resource=resource, metric_readers=[metric_reader]
+            )
+            metrics.set_meter_provider(metric_provider)
+            self.meter = metrics.get_meter(__name__)
+        if TelemetrySink.SQLITE in self.config.sinks:
+            trace.get_tracer_provider().add_span_processor(
+                SQLiteSpanProcessor(self.config.sqlite_db_path)
+            )
+            self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
+        if TelemetrySink.CONSOLE in self.config.sinks:
+            trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
+        self._lock = _global_lock
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        trace.get_tracer_provider().force_flush()
+        trace.get_tracer_provider().shutdown()
+        metrics.get_meter_provider().shutdown()
+
+    async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
+        if isinstance(event, UnstructuredLogEvent):
+            self._log_unstructured(event, ttl_seconds)
+        elif isinstance(event, MetricEvent):
+            self._log_metric(event)
+        elif isinstance(event, StructuredLogEvent):
+            self._log_structured(event, ttl_seconds)
+        else:
+            raise ValueError(f"Unknown event type: {event}")
+
+    def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
+        with self._lock:
+            # Use global storage instead of instance storage
+            span_id = string_to_span_id(event.span_id)
+            span = _GLOBAL_STORAGE["active_spans"].get(span_id)
+
+            if span:
+                timestamp_ns = int(event.timestamp.timestamp() * 1e9)
+                span.add_event(
+                    name=event.type,
+                    attributes={
+                        "message": event.message,
+                        "severity": event.severity.value,
+                        "__ttl__": ttl_seconds,
+                        **event.attributes,
+                    },
+                    timestamp=timestamp_ns,
+                )
+            else:
+                print(
+                    f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
+                )
+
+    def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
+        if name not in _GLOBAL_STORAGE["counters"]:
+            _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
+                name=name,
+                unit=unit,
+                description=f"Counter for {name}",
+            )
+        return _GLOBAL_STORAGE["counters"][name]
+
+    def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
+        if name not in _GLOBAL_STORAGE["gauges"]:
+            _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
+                name=name,
+                unit=unit,
+                description=f"Gauge for {name}",
+            )
+        return _GLOBAL_STORAGE["gauges"][name]
+
+    def _log_metric(self, event: MetricEvent) -> None:
+        if isinstance(event.value, int):
+            counter = self._get_or_create_counter(event.metric, event.unit)
+            counter.add(event.value, attributes=event.attributes)
+        elif isinstance(event.value, float):
+            up_down_counter = self._get_or_create_up_down_counter(
+                event.metric, event.unit
+            )
+            up_down_counter.add(event.value, attributes=event.attributes)
+
+    def _get_or_create_up_down_counter(
+        self, name: str, unit: str
+    ) -> metrics.UpDownCounter:
+        if name not in _GLOBAL_STORAGE["up_down_counters"]:
+            _GLOBAL_STORAGE["up_down_counters"][name] = (
+                self.meter.create_up_down_counter(
+                    name=name,
+                    unit=unit,
+                    description=f"UpDownCounter for {name}",
+                )
+            )
+        return _GLOBAL_STORAGE["up_down_counters"][name]
+
+    def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
+        with self._lock:
+            span_id = string_to_span_id(event.span_id)
+            trace_id = string_to_trace_id(event.trace_id)
+            tracer = trace.get_tracer(__name__)
+            if event.attributes is None:
+                event.attributes = {}
+            event.attributes["__ttl__"] = ttl_seconds
+
+            if isinstance(event.payload, SpanStartPayload):
+                # Check if span already exists to prevent duplicates
+                if span_id in _GLOBAL_STORAGE["active_spans"]:
+                    return
+
+                parent_span = None
+                if event.payload.parent_span_id:
+                    parent_span_id = string_to_span_id(event.payload.parent_span_id)
+                    parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
+
+                context = trace.Context(trace_id=trace_id)
+                if parent_span:
+                    context = trace.set_span_in_context(parent_span, context)
+
+                span = tracer.start_span(
+                    name=event.payload.name,
+                    context=context,
+                    attributes=event.attributes or {},
+                )
+                _GLOBAL_STORAGE["active_spans"][span_id] = span
+
+            elif isinstance(event.payload, SpanEndPayload):
+                span = _GLOBAL_STORAGE["active_spans"].get(span_id)
+                if span:
+                    if event.attributes:
+                        span.set_attributes(event.attributes)
+
+                    status = (
+                        trace.Status(status_code=trace.StatusCode.OK)
+                        if event.payload.status == SpanStatus.OK
+                        else trace.Status(status_code=trace.StatusCode.ERROR)
+                    )
+                    span.set_status(status)
+                    span.end()
+                    _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
+            else:
+                raise ValueError(f"Unknown structured log event: {event}")
+
+    async def query_traces(
+        self,
+        attribute_filters: Optional[List[QueryCondition]] = None,
+        limit: Optional[int] = 100,
+        offset: Optional[int] = 0,
+        order_by: Optional[List[str]] = None,
+    ) -> List[Trace]:
+        return await self.trace_store.query_traces(
+            attribute_filters=attribute_filters,
+            limit=limit,
+            offset=offset,
+            order_by=order_by,
+        )
+
+    async def get_span_tree(
+        self,
+        span_id: str,
+        attributes_to_return: Optional[List[str]] = None,
+        max_depth: Optional[int] = None,
+    ) -> SpanWithChildren:
+        return await self.trace_store.get_materialized_span(
+            span_id=span_id,
+            attributes_to_return=attributes_to_return,
+            max_depth=max_depth,
+        )
diff --git a/llama_stack/providers/utils/telemetry/sqlite.py b/llama_stack/providers/utils/telemetry/sqlite.py
index 4ccabf200..e7161fffa 100644
--- a/llama_stack/providers/utils/telemetry/sqlite.py
+++ b/llama_stack/providers/utils/telemetry/sqlite.py
@@ -11,8 +11,8 @@ from typing import List, Optional
 import aiosqlite
 
 from llama_stack.apis.telemetry import (
-    MaterializedSpan,
     QueryCondition,
+    SpanWithChildren,
     Trace,
-    TraceStore,
 )
+from llama_stack.providers.utils.telemetry.sqlite_trace_store import TraceStore
@@ -24,56 +24,76 @@ class SQLiteTraceStore(TraceStore):
 
     async def query_traces(
         self,
-        attribute_conditions: Optional[List[QueryCondition]] = None,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attribute_filters: Optional[List[QueryCondition]] = None,
+        attributes_to_return: Optional[List[str]] = None,
         limit: Optional[int] = 100,
         offset: Optional[int] = 0,
         order_by: Optional[List[str]] = None,
     ) -> List[Trace]:
-        # Build the SQL query with attribute selection
-        select_clause = """
-            SELECT DISTINCT t.trace_id, t.root_span_id, t.start_time, t.end_time
-        """
-        if attribute_keys_to_return:
-            for key in attribute_keys_to_return:
-                select_clause += (
-                    f", json_extract(s.attributes, '$.{key}') as attr_{key}"
-                )
-
-        query = (
-            select_clause
-            + """
-            FROM traces t
-            JOIN spans s ON t.trace_id = s.trace_id
-        """
-        )
-        params = []
-
-        # Add attribute conditions if present
-        if attribute_conditions:
-            conditions = []
-            for condition in attribute_conditions:
-                conditions.append(
-                    f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?"
-                )
-                params.append(condition.value)
-            if conditions:
-                query += " WHERE " + " AND ".join(conditions)
+        def build_attribute_select() -> str:
+            if not attributes_to_return:
+                return ""
+            return "".join(
+                f", json_extract(s.attributes, '$.{key}') as attr_{key}"
+                for key in attributes_to_return
+            )
 
+        def build_where_clause() -> tuple[str, list]:
+            if not attribute_filters:
+                return "", []
+
+            ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
+
+            conditions = [
+                f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?"
+                for condition in attribute_filters
+            ]
+            params = [condition.value for condition in attribute_filters]
+            where_clause = " WHERE " + " AND ".join(conditions)
+            return where_clause, params
+
+        def build_order_clause() -> str:
+            if not order_by:
+                return ""
 
-        # Add ordering
-        if order_by:
             order_clauses = []
             for field in order_by:
-                desc = False
-                if field.startswith("-"):
-                    field = field[1:]
-                    desc = True
-                order_clauses.append(f"t.{field} {'DESC' if desc else 'ASC'}")
-            query += " ORDER BY " + ", ".join(order_clauses)
+                desc = field.startswith("-")
+                clean_field = field[1:] if desc else field
+                order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
+            return " ORDER BY " + ", ".join(order_clauses)
 
-        # Add limit and offset
-        query += f" LIMIT {limit} OFFSET {offset}"
+        # Build the main query
+        base_query = """
+            WITH matching_traces AS (
+                SELECT DISTINCT t.trace_id
+                FROM traces t
+                JOIN spans s ON t.trace_id = s.trace_id
+                {where_clause}
+            ),
+            filtered_traces AS (
+                SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
+                {attribute_select}
+                FROM matching_traces mt
+                JOIN traces t ON mt.trace_id = t.trace_id
+                LEFT JOIN spans s ON t.trace_id = s.trace_id
+                {order_clause}
+            )
+            SELECT DISTINCT trace_id, root_span_id, start_time, end_time
+            FROM filtered_traces
+            LIMIT {limit} OFFSET {offset}
+        """
+
+        where_clause, params = build_where_clause()
+        query = base_query.format(
+            attribute_select=build_attribute_select(),
+            where_clause=where_clause,
+            order_clause=build_order_clause(),
+            limit=limit,
+            offset=offset,
+        )
+
+        # Execute query and return results
         async with aiosqlite.connect(self.conn_string) as conn:
             conn.row_factory = aiosqlite.Row
             async with conn.execute(query, params) as cursor:
@@ -91,15 +111,15 @@ class SQLiteTraceStore(TraceStore):
     async def get_materialized_span(
         self,
         span_id: str,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attributes_to_return: Optional[List[str]] = None,
         max_depth: Optional[int] = None,
-    ) -> MaterializedSpan:
+    ) -> SpanWithChildren:
         # Build the attributes selection
         attributes_select = "s.attributes"
-        if attribute_keys_to_return:
+        if attributes_to_return:
             json_object = ", ".join(
                 f"'{key}', json_extract(s.attributes, '$.{key}')"
-                for key in attribute_keys_to_return
+                for key in attributes_to_return
             )
             attributes_select = f"json_object({json_object})"
@@ -135,7 +155,7 @@ class SQLiteTraceStore(TraceStore):
         root_span = None
 
         for row in rows:
-            span = MaterializedSpan(
+            span = SpanWithChildren(
                 span_id=row["span_id"],
                 trace_id=row["trace_id"],
                 parent_span_id=row["parent_span_id"],
diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py
new file mode 100644
index 000000000..ed1343e0b
--- /dev/null
+++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py
@@ -0,0 +1,180 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from datetime import datetime
+from typing import List, Optional, Protocol
+
+import aiosqlite
+
+from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
+
+
+class TraceStore(Protocol):
+
+    async def query_traces(
+        self,
+        attribute_filters: Optional[List[QueryCondition]] = None,
+        limit: Optional[int] = 100,
+        offset: Optional[int] = 0,
+        order_by: Optional[List[str]] = None,
+    ) -> List[Trace]: ...
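+
+    # Returns the span subtree rooted at span_id, with children nested via
+    # parent_span_id links and truncated at max_depth when provided.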
+    async def get_materialized_span(
+        self,
+        span_id: str,
+        attributes_to_return: Optional[List[str]] = None,
+        max_depth: Optional[int] = None,
+    ) -> SpanWithChildren: ...
+
+
+class SQLiteTraceStore(TraceStore):
+    def __init__(self, conn_string: str):
+        self.conn_string = conn_string
+
+    async def query_traces(
+        self,
+        attribute_filters: Optional[List[QueryCondition]] = None,
+        limit: Optional[int] = 100,
+        offset: Optional[int] = 0,
+        order_by: Optional[List[str]] = None,
+    ) -> List[Trace]:
+
+        def build_where_clause() -> tuple[str, list]:
+            if not attribute_filters:
+                return "", []
+
+            ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
+
+            conditions = [
+                f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?"
+                for condition in attribute_filters
+            ]
+            params = [condition.value for condition in attribute_filters]
+            where_clause = " WHERE " + " AND ".join(conditions)
+            return where_clause, params
+
+        def build_order_clause() -> str:
+            if not order_by:
+                return ""
+
+            order_clauses = []
+            for field in order_by:
+                desc = field.startswith("-")
+                clean_field = field[1:] if desc else field
+                order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
+            return " ORDER BY " + ", ".join(order_clauses)
+
+        # Build the main query
+        base_query = """
+            WITH matching_traces AS (
+                SELECT DISTINCT t.trace_id
+                FROM traces t
+                JOIN spans s ON t.trace_id = s.trace_id
+                {where_clause}
+            ),
+            filtered_traces AS (
+                SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
+                FROM matching_traces mt
+                JOIN traces t ON mt.trace_id = t.trace_id
+                LEFT JOIN spans s ON t.trace_id = s.trace_id
+                {order_clause}
+            )
+            SELECT DISTINCT trace_id, root_span_id, start_time, end_time
+            FROM filtered_traces
+            LIMIT {limit} OFFSET {offset}
+        """
+
+        where_clause, params = build_where_clause()
+        query = base_query.format(
+            where_clause=where_clause,
+            order_clause=build_order_clause(),
+            limit=limit,
+            offset=offset,
+        )
+
+        # Execute query and return results
+        async with aiosqlite.connect(self.conn_string) as conn:
+            conn.row_factory = aiosqlite.Row
+            async with conn.execute(query, params) as cursor:
+                rows = await cursor.fetchall()
+                return [
+                    Trace(
+                        trace_id=row["trace_id"],
+                        root_span_id=row["root_span_id"],
+                        start_time=datetime.fromisoformat(row["start_time"]),
+                        end_time=datetime.fromisoformat(row["end_time"]),
+                    )
+                    for row in rows
+                ]
+
+    async def get_materialized_span(
+        self,
+        span_id: str,
+        attributes_to_return: Optional[List[str]] = None,
+        max_depth: Optional[int] = None,
+    ) -> SpanWithChildren:
+        # Build the attributes selection
+        attributes_select = "s.attributes"
+        if attributes_to_return:
+            json_object = ", ".join(
+                f"'{key}', json_extract(s.attributes, '$.{key}')"
+                for key in attributes_to_return
+            )
+            attributes_select = f"json_object({json_object})"
+
+        # SQLite CTE query with filtered attributes
+        query = f"""
+            WITH RECURSIVE span_tree AS (
+                SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
+                FROM spans s
+                WHERE s.span_id = ?
+
+                UNION ALL
+
+                SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
+                FROM spans s
+                JOIN span_tree st ON s.parent_span_id = st.span_id
+                WHERE (? IS NULL OR st.depth < ?)
+            )
+            SELECT *
+            FROM span_tree
+            ORDER BY depth, start_time
+        """
+
+        async with aiosqlite.connect(self.conn_string) as conn:
+            conn.row_factory = aiosqlite.Row
+            async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
+                rows = await cursor.fetchall()
+
+                if not rows:
+                    raise ValueError(f"Span {span_id} not found")
+
+                # Build span tree
+                spans_by_id = {}
+                root_span = None
+
+                for row in rows:
+                    span = SpanWithChildren(
+                        span_id=row["span_id"],
+                        trace_id=row["trace_id"],
+                        parent_span_id=row["parent_span_id"],
+                        name=row["name"],
+                        start_time=datetime.fromisoformat(row["start_time"]),
+                        end_time=datetime.fromisoformat(row["end_time"]),
+                        attributes=json.loads(row["filtered_attributes"]),
+                        status=row["status"].lower(),
+                        children=[],
+                    )
+
+                    spans_by_id[span.span_id] = span
+
+                    if span.span_id == span_id:
+                        root_span = span
+                    elif span.parent_span_id in spans_by_id:
+                        spans_by_id[span.parent_span_id].children.append(span)
+
+                return root_span
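For completeness, a minimal end-to-end sketch of the new store. The database path and the attribute filter are illustrative; it assumes the `traces` and `spans` tables were already populated by `SQLiteSpanProcessor`:

    import asyncio

    from llama_stack.apis.telemetry import QueryCondition
    from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore

    async def main() -> None:
        store = SQLiteTraceStore("/tmp/llama_stack_traces.db")
        traces = await store.query_traces(
            attribute_filters=[
                QueryCondition(key="class", op="eq", value="InferenceRouter")
            ],
            limit=5,
            order_by=["-start_time"],
        )
        for t in traces:
            tree = await store.get_materialized_span(t.root_span_id, max_depth=3)
            print(tree.name, len(tree.children))

    asyncio.run(main())

The "class" attribute is one of the keys trace_protocol stamps on every autotraced span ("class", "method", "type", "args"), so it is a natural filter for narrowing results to a single component.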