address feedback

Dinesh Yeduguru 2024-12-04 09:25:24 -08:00
parent 32af1f9dd4
commit b8c395c264
15 changed files with 672 additions and 151 deletions


@@ -23,7 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
-from llama_stack.distribution.tracing import trace_protocol, traced
+from llama_stack.distribution.tracing import trace_protocol
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.common.deployment_types import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
@@ -428,7 +428,6 @@
    ) -> AgentCreateResponse: ...

    @webmethod(route="/agents/turn/create")
-    @traced(input="messages")
    async def create_agent_turn(
        self,
        agent_id: str,


@@ -38,7 +38,7 @@ class DatasetIO(Protocol):
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult: ...

-    @webmethod(route="/datasetio/upload", method="POST")
-    async def upload_rows(
+    @webmethod(route="/datasetio/append-rows", method="POST")
+    async def append_rows(
        self, dataset_id: str, rows: List[Dict[str, Any]]
    ) -> None: ...
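The rename from /datasetio/upload to /datasetio/append-rows makes the write semantics explicit: rows are added to an existing dataset rather than replacing it. A minimal caller sketch against the protocol; `impl` is a hypothetical stand-in for any registered DatasetIO provider:

from typing import Any, Dict, List

async def add_eval_rows(impl, dataset_id: str) -> None:
    # `impl` is any object implementing the DatasetIO protocol above.
    rows: List[Dict[str, Any]] = [
        {"input": "What is 2 + 2?", "expected_output": "4"},
        {"input": "Capital of France?", "expected_output": "Paris"},
    ]
    # Appends to the dataset; existing rows are untouched.
    await impl.append_rows(dataset_id=dataset_id, rows=rows)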


@@ -21,7 +21,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
-from llama_stack.distribution.tracing import trace_protocol, traced
+from llama_stack.distribution.tracing import trace_protocol
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.models import *  # noqa: F403
@@ -227,7 +227,6 @@ class Inference(Protocol):
    model_store: ModelStore

    @webmethod(route="/inference/completion")
-    @traced(input="content")
    async def completion(
        self,
        model_id: str,
@@ -239,7 +238,6 @@ class Inference(Protocol):
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...

    @webmethod(route="/inference/chat-completion")
-    @traced(input="messages")
    async def chat_completion(
        self,
        model_id: str,
@@ -257,7 +255,6 @@ class Inference(Protocol):
    ]: ...

    @webmethod(route="/inference/embeddings")
-    @traced(input="contents")
    async def embeddings(
        self,
        model_id: str,


@@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.memory_banks import *  # noqa: F403
-from llama_stack.distribution.tracing import trace_protocol, traced
+from llama_stack.distribution.tracing import trace_protocol


@json_schema_type
@@ -50,7 +50,6 @@ class Memory(Protocol):
    # this will just block now until documents are inserted, but it should
    # probably return a Job instance which can be polled for completion
-    @traced(input="documents")
    @webmethod(route="/memory/insert")
    async def insert_documents(
        self,
@@ -60,7 +59,6 @@ class Memory(Protocol):
    ) -> None: ...

    @webmethod(route="/memory/query")
-    @traced(input="query")
    async def query_documents(
        self,
        bank_id: str,


@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
-from llama_stack.distribution.tracing import trace_protocol, traced
+from llama_stack.distribution.tracing import trace_protocol
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403
@@ -50,7 +50,6 @@ class Safety(Protocol):
    shield_store: ShieldStore

    @webmethod(route="/safety/run-shield")
-    @traced(input="messages")
    async def run_shield(
        self,
        shield_id: str,


@@ -21,6 +21,9 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated

+# Default TTL for telemetry events: 7 days (7 * 86400 = 604800 seconds)
+DEFAULT_TTL_DAYS = 7


@json_schema_type
class SpanStatus(Enum):
@@ -147,57 +150,39 @@ class EvalTrace(BaseModel):
@json_schema_type
-class MaterializedSpan(Span):
-    children: List["MaterializedSpan"] = Field(default_factory=list)
+class SpanWithChildren(Span):
+    children: List["SpanWithChildren"] = Field(default_factory=list)
    status: Optional[SpanStatus] = None


@json_schema_type
class QueryCondition(BaseModel):
    key: str
-    op: str
+    op: Literal["eq", "ne", "gt", "lt"]
    value: Any


-class TraceStore(Protocol):
-    async def query_traces(
-        self,
-        attribute_conditions: Optional[List[QueryCondition]] = None,
-        attribute_keys_to_return: Optional[List[str]] = None,
-        limit: Optional[int] = 100,
-        offset: Optional[int] = 0,
-        order_by: Optional[List[str]] = None,
-    ) -> List[Trace]: ...
-
-    async def get_materialized_span(
-        self,
-        span_id: str,
-        attribute_keys_to_return: Optional[List[str]] = None,
-        max_depth: Optional[int] = None,
-    ) -> MaterializedSpan: ...
-
-
@runtime_checkable
class Telemetry(Protocol):
    @webmethod(route="/telemetry/log-event")
-    async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: ...
+    async def log_event(
+        self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
+    ) -> None: ...

-    @webmethod(route="/telemetry/query-traces", method="GET")
+    @webmethod(route="/telemetry/query-traces", method="POST")
    async def query_traces(
        self,
-        attribute_conditions: Optional[List[QueryCondition]] = None,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attribute_filters: Optional[List[QueryCondition]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]: ...

-    @webmethod(route="/telemetry/get-materialized-span", method="GET")
-    async def get_materialized_span(
+    @webmethod(route="/telemetry/get-span-tree", method="POST")
+    async def get_span_tree(
        self,
        span_id: str,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attributes_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
-    ) -> MaterializedSpan: ...
+    ) -> SpanWithChildren: ...
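Narrowing QueryCondition.op from a free-form str to a Literal rejects unsupported operators at validation time, before they reach SQL generation. A usage sketch against the revised protocol; `telemetry` is a hypothetical handle to any Telemetry implementation:

async def newest_session_tree(telemetry, session_id: str) -> SpanWithChildren:
    # Find recent traces whose spans carry a matching session_id attribute.
    traces = await telemetry.query_traces(
        attribute_filters=[
            QueryCondition(key="session_id", op="eq", value=session_id)
        ],
        limit=10,
        order_by=["-start_time"],  # a leading "-" sorts descending
    )
    # Expand the newest trace into a span tree, keeping only two attributes
    # per span and stopping two levels below the root.
    return await telemetry.get_span_tree(
        span_id=traces[0].root_span_id,
        attributes_to_return=["input", "output"],
        max_depth=2,
    )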


@@ -222,8 +222,8 @@ class DatasetIORouter(DatasetIO):
            filter_condition=filter_condition,
        )

-    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        return await self.routing_table.get_provider_impl(dataset_id).upload_rows(
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )


@@ -24,7 +24,7 @@ def serialize_value(value: Any) -> str:
        return value.model_dump_json()
    elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
        return json.dumps([item.model_dump_json() for item in value])
-    elif hasattr(value, "to_dict"):  # For objects with to_dict method
+    elif hasattr(value, "to_dict"):
        return json.dumps(value.to_dict())
    elif isinstance(value, (dict, list, int, float, str, bool)):
        return json.dumps(value)
@@ -34,21 +34,6 @@ def serialize_value(value: Any) -> str:
        return str(value)


-def traced(input: str = None):
-    """
-    A method decorator that enables tracing with input and output capture.
-
-    Args:
-        input: Name of the input parameter to capture in traces
-    """
-
-    def decorator(method: Callable) -> Callable:
-        method._trace_input = input
-        return method
-
-    return decorator
-
-
def trace_protocol(cls: Type[T]) -> Type[T]:
    """
    A class decorator that automatically traces all methods in a protocol/base class
@@ -59,22 +44,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
        is_async = asyncio.iscoroutinefunction(method)
        is_async_gen = inspect.isasyncgenfunction(method)

-        def get_traced_input(args: tuple, kwargs: dict) -> dict:
-            trace_input = getattr(method, "_trace_input", None)
-            if not trace_input:
-                return {}
-
-            # Get the mapping of parameter names to values
-            sig = inspect.signature(method)
-            bound_args = sig.bind(None, *args, **kwargs)  # None for self
-            bound_args.apply_defaults()
-            params = dict(list(bound_args.arguments.items())[1:])  # Skip 'self'
-
-            # Return the input value if the key exists
-            if trace_input in params:
-                return {"input": serialize_value(params[trace_input])}
-            return {}
-
        def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
            class_name = self.__class__.__name__
            method_name = method.__name__
@@ -87,7 +56,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
                "method": method_name,
                "type": span_type,
                "args": serialize_value(args),
-                **get_traced_input(args, kwargs),
            }
            return class_name, method_name, span_attributes
@@ -145,33 +113,16 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
        else:
            return sync_wrapper

-    # Store the original __init_subclass__ if it exists
    original_init_subclass = getattr(cls, "__init_subclass__", None)

-    # Define a new __init_subclass__ to handle child classes
    def __init_subclass__(cls_child, **kwargs):  # noqa: N807
-        # Call original __init_subclass__ if it exists
        if original_init_subclass:
            original_init_subclass(**kwargs)

-        traced_methods = {}
-        for parent in cls_child.__mro__[1:]:  # Skip the class itself
-            for name, method in vars(parent).items():
-                if inspect.isfunction(method) and getattr(
-                    method, "_trace_input", None
-                ):  # noqa: B009
-                    traced_methods[name] = getattr(method, "_trace_input")  # noqa: B009
-
-        # Trace child class methods if their name matches a traced parent method
        for name, method in vars(cls_child).items():
            if inspect.isfunction(method) and not name.startswith("_"):
-                if name in traced_methods:
-                    # Copy the trace configuration from the parent
-                    setattr(method, "_trace_input", traced_methods[name])  # noqa: B010
                setattr(cls_child, name, trace_method(method))  # noqa: B010

-    # Set the new __init_subclass__
    cls.__init_subclass__ = classmethod(__init_subclass__)

    return cls
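With the traced decorator gone, opt-in tracing disappears: __init_subclass__ now wraps every public method of every concrete subclass at class-creation time. A self-contained sketch of the same hook, with a simplified trace_method that only logs entry and exit (the real wrapper opens spans):

import asyncio
import inspect

def trace_protocol(cls):
    def trace_method(method):
        # Simplified: assumes coroutine methods, which is what the protocols use.
        async def wrapper(self, *args, **kwargs):
            print(f"-> {self.__class__.__name__}.{method.__name__}")
            try:
                return await method(self, *args, **kwargs)
            finally:
                print(f"<- {self.__class__.__name__}.{method.__name__}")
        return wrapper

    def __init_subclass__(cls_child, **kwargs):
        # Wrap every public function defined on the subclass.
        for name, method in vars(cls_child).items():
            if inspect.isfunction(method) and not name.startswith("_"):
                setattr(cls_child, name, trace_method(method))

    cls.__init_subclass__ = classmethod(__init_subclass__)
    return cls

@trace_protocol
class Inference:
    pass

class MyInference(Inference):  # methods are wrapped when the class is created
    async def completion(self, prompt: str) -> str:
        return prompt.upper()

asyncio.run(MyInference().completion("hi"))  # prints the enter/exit lines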


@@ -132,7 +132,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
            next_page_token=str(end),
        )

-    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        dataset_info = self.dataset_infos.get(dataset_id)
        if dataset_info is None:
            raise ValueError(f"Dataset with id {dataset_id} not found")


@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import List, Optional

from llama_stack.apis.telemetry import *  # noqa: F403

from .config import ConsoleConfig, LogFormat
class ConsoleTelemetryImpl(Telemetry):
def __init__(self, config: ConsoleConfig) -> None:
self.config = config
self.spans = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def log_event(
    self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
) -> None:
if (
isinstance(event, StructuredLogEvent)
and event.payload.type == StructuredLogType.SPAN_START.value
):
self.spans[event.span_id] = event.payload
names = []
span_id = event.span_id
while True:
span_payload = self.spans.get(span_id)
if not span_payload:
break
names = [span_payload.name] + names
span_id = span_payload.parent_span_id
span_name = ".".join(names) if names else None
if self.config.log_format == LogFormat.JSON:
formatted = format_event_json(event, span_name)
else:
formatted = format_event_text(event, span_name)
if formatted:
print(formatted)
async def query_traces(
    self,
    attribute_filters: Optional[List[QueryCondition]] = None,
    limit: Optional[int] = 100,
    offset: Optional[int] = 0,
    order_by: Optional[List[str]] = None,
) -> List[Trace]:
    raise NotImplementedError("Console telemetry does not support trace querying")
async def get_span_tree(
    self,
    span_id: str,
    attributes_to_return: Optional[List[str]] = None,
    max_depth: Optional[int] = None,
) -> SpanWithChildren:
    raise NotImplementedError("Console telemetry does not support span querying")
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
SEVERITY_COLORS = {
LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
LogSeverity.DEBUG: COLORS["cyan"],
LogSeverity.INFO: COLORS["green"],
LogSeverity.WARN: COLORS["yellow"],
LogSeverity.ERROR: COLORS["red"],
LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
}
def format_event_text(event: Event, span_name: str) -> Optional[str]:
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
span = ""
if span_name:
span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
if isinstance(event, UnstructuredLogEvent):
severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
return (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
f"{span}"
f"{event.message}"
)
elif isinstance(event, StructuredLogEvent):
return None
return f"Unknown event type: {event}"
def format_event_json(event: Event, span_name: str) -> Optional[str]:
base_data = {
"timestamp": event.timestamp.isoformat(),
"trace_id": event.trace_id,
"span_id": event.span_id,
"span_name": span_name,
}
if isinstance(event, UnstructuredLogEvent):
base_data.update(
{"type": "log", "severity": event.severity.name, "message": event.message}
)
return json.dumps(base_data)
elif isinstance(event, StructuredLogEvent):
return None
return json.dumps({"error": f"Unknown event type: {event}"})
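The parent-pointer walk in log_event is what turns flat span events into the dotted path shown on the console. A worked illustration of the same walk, with plain dicts standing in for the cached SpanStartPayloads:

spans = {
    "s1": {"name": "create_agent_turn", "parent_span_id": None},
    "s2": {"name": "chat_completion", "parent_span_id": "s1"},
}

names: list = []
span_id = "s2"
while span_id is not None:
    payload = spans.get(span_id)
    if not payload:
        break
    names.insert(0, payload["name"])   # prepend, so the root comes first
    span_id = payload["parent_span_id"]

assert ".".join(names) == "create_agent_turn.chat_completion"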


@@ -24,7 +24,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.console_span_processo
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
    SQLiteSpanProcessor,
)
-from llama_stack.providers.utils.telemetry.sqlite import SQLiteTraceStore
+from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore

from llama_stack.apis.telemetry import *  # noqa: F403
@@ -222,28 +222,26 @@ class TelemetryAdapter(Telemetry):
    async def query_traces(
        self,
-        attribute_conditions: Optional[List[QueryCondition]] = None,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attribute_filters: Optional[List[QueryCondition]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]:
        return await self.trace_store.query_traces(
-            attribute_conditions=attribute_conditions,
-            attribute_keys_to_return=attribute_keys_to_return,
+            attribute_filters=attribute_filters,
            limit=limit,
            offset=offset,
            order_by=order_by,
        )

-    async def get_materialized_span(
+    async def get_span_tree(
        self,
        span_id: str,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attributes_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
-    ) -> MaterializedSpan:
+    ) -> SpanWithChildren:
        return await self.trace_store.get_materialized_span(
            span_id=span_id,
-            attribute_keys_to_return=attribute_keys_to_return,
+            attributes_to_return=attributes_to_return,
            max_depth=max_depth,
        )


@@ -96,7 +96,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
            next_page_token=str(end),
        )

-    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        dataset_def = self.dataset_infos[dataset_id]
        loaded_dataset = load_hf_dataset(dataset_def)


@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import threading
from typing import List, Optional
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.remote.telemetry.opentelemetry.console_span_processor import (
ConsoleSpanProcessor,
)
from llama_stack.providers.remote.telemetry.opentelemetry.sqlite_span_processor import (
SQLiteSpanProcessor,
)
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
"active_spans": {},
"counters": {},
"gauges": {},
"up_down_counters": {},
}
_global_lock = threading.Lock()
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
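Both helpers derive OpenTelemetry-compatible integer IDs deterministically from the stack's string IDs, so the same string always maps to the same trace or span. A quick worked check; the values follow directly from the big-endian byte interpretation:

# "abc" encodes to bytes 0x61 0x62 0x63; read big-endian, that is 0x616263.
assert int.from_bytes(b"abc", byteorder="big", signed=False) == 0x616263  # 6382179

# Span IDs keep only the first 8 bytes, so longer strings are truncated:
assert string_to_span_id("0123456789") == int.from_bytes(b"01234567", byteorder="big")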
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
self.config = config
self.datasetio = deps[Api.datasetio]
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
)
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
if TelemetrySink.JAEGER in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(
    self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.type,
attributes={
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
},
timestamp=timestamp_ns,
)
else:
print(
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
)
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
)
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(
event.metric, event.unit
)
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(
self, name: str, unit: str
) -> metrics.UpDownCounter:
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = (
self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
)
return _GLOBAL_STORAGE["up_down_counters"][name]
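_log_metric dispatches purely on the Python type of value: ints feed a monotonic Counter, floats an UpDownCounter (which also accepts decreases). An illustrative pair of events; only the fields the handler reads are shown, and `adapter` stands for a constructed adapter instance:

# Illustrative only: a real MetricEvent also carries the base Event fields
# (trace_id, span_id, timestamp), omitted here for brevity.
tokens = MetricEvent(
    metric="prompt_tokens", value=128, unit="tokens",
    attributes={"model_id": "llama-3-8b"},
)
adapter._log_metric(tokens)  # int -> Counter.add(128)

load = MetricEvent(
    metric="gpu_utilization", value=0.87, unit="ratio", attributes={},
)
adapter._log_metric(load)    # float -> UpDownCounter.add(0.87)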
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
parent_span = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
)
span.set_status(status)
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")
async def query_traces(
    self,
    attribute_filters: Optional[List[QueryCondition]] = None,
    limit: Optional[int] = 100,
    offset: Optional[int] = 0,
    order_by: Optional[List[str]] = None,
) -> List[Trace]:
    return await self.trace_store.query_traces(
        attribute_filters=attribute_filters,
        limit=limit,
        offset=offset,
        order_by=order_by,
    )
async def get_span_tree(
    self,
    span_id: str,
    attributes_to_return: Optional[List[str]] = None,
    max_depth: Optional[int] = None,
) -> SpanWithChildren:
    return await self.trace_store.get_materialized_span(
        span_id=span_id,
        attributes_to_return=attributes_to_return,
        max_depth=max_depth,
    )


@@ -11,8 +11,8 @@ from typing import List, Optional
import aiosqlite

from llama_stack.apis.telemetry import (
-    MaterializedSpan,
    QueryCondition,
+    SpanWithChildren,
    Trace,
    TraceStore,
)
@@ -24,56 +24,76 @@ class SQLiteTraceStore(TraceStore):
    async def query_traces(
        self,
-        attribute_conditions: Optional[List[QueryCondition]] = None,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attribute_filters: Optional[List[QueryCondition]] = None,
+        attributes_to_return: Optional[List[str]] = None,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0,
        order_by: Optional[List[str]] = None,
    ) -> List[Trace]:
-        # Build the SQL query with attribute selection
-        select_clause = """
-        SELECT DISTINCT t.trace_id, t.root_span_id, t.start_time, t.end_time
-        """
-        if attribute_keys_to_return:
-            for key in attribute_keys_to_return:
-                select_clause += (
+        def build_attribute_select() -> str:
+            if not attributes_to_return:
+                return ""
+            return "".join(
                f", json_extract(s.attributes, '$.{key}') as attr_{key}"
+                for key in attributes_to_return
            )

-        query = (
-            select_clause
-            + """
-        FROM traces t
-        JOIN spans s ON t.trace_id = s.trace_id
-        """
-        )
-        params = []
+        def build_where_clause() -> tuple[str, list]:
+            if not attribute_filters:
+                return "", []

-        # Add attribute conditions if present
-        if attribute_conditions:
-            conditions = []
-            for condition in attribute_conditions:
-                conditions.append(
+            conditions = [
                f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?"
-                )
-                params.append(condition.value)
-            if conditions:
-                query += " WHERE " + " AND ".join(conditions)
+                for condition in attribute_filters
+            ]
+            params = [condition.value for condition in attribute_filters]
+            where_clause = " WHERE " + " AND ".join(conditions)
+            return where_clause, params

-        # Add ordering
-        if order_by:
+        def build_order_clause() -> str:
+            if not order_by:
+                return ""
            order_clauses = []
            for field in order_by:
-                desc = False
-                if field.startswith("-"):
-                    field = field[1:]
-                    desc = True
-                order_clauses.append(f"t.{field} {'DESC' if desc else 'ASC'}")
-            query += " ORDER BY " + ", ".join(order_clauses)
+                desc = field.startswith("-")
+                clean_field = field[1:] if desc else field
+                order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
+            return " ORDER BY " + ", ".join(order_clauses)

-        # Add limit and offset
-        query += f" LIMIT {limit} OFFSET {offset}"
+        # Build the main query
+        base_query = """
+            WITH matching_traces AS (
+                SELECT DISTINCT t.trace_id
+                FROM traces t
+                JOIN spans s ON t.trace_id = s.trace_id
+                {where_clause}
+            ),
+            filtered_traces AS (
+                SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
+                {attribute_select}
+                FROM matching_traces mt
+                JOIN traces t ON mt.trace_id = t.trace_id
+                LEFT JOIN spans s ON t.trace_id = s.trace_id
+                {order_clause}
+            )
+            SELECT DISTINCT trace_id, root_span_id, start_time, end_time
+            FROM filtered_traces
+            LIMIT {limit} OFFSET {offset}
+        """
+
+        where_clause, params = build_where_clause()
+        query = base_query.format(
+            attribute_select=build_attribute_select(),
+            where_clause=where_clause,
+            order_clause=build_order_clause(),
+            limit=limit,
+            offset=offset,
+        )

        # Execute query and return results
        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, params) as cursor:
@@ -91,15 +111,15 @@ class SQLiteTraceStore(TraceStore):
    async def get_materialized_span(
        self,
        span_id: str,
-        attribute_keys_to_return: Optional[List[str]] = None,
+        attributes_to_return: Optional[List[str]] = None,
        max_depth: Optional[int] = None,
-    ) -> MaterializedSpan:
+    ) -> SpanWithChildren:
        # Build the attributes selection
        attributes_select = "s.attributes"
-        if attribute_keys_to_return:
+        if attributes_to_return:
            json_object = ", ".join(
                f"'{key}', json_extract(s.attributes, '$.{key}')"
-                for key in attribute_keys_to_return
+                for key in attributes_to_return
            )
            attributes_select = f"json_object({json_object})"
@@ -135,7 +155,7 @@ class SQLiteTraceStore(TraceStore):
        root_span = None

        for row in rows:
-            span = MaterializedSpan(
+            span = SpanWithChildren(
                span_id=row["span_id"],
                trace_id=row["trace_id"],
                parent_span_id=row["parent_span_id"],


@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import List, Optional, Protocol
import aiosqlite
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
class TraceStore(Protocol):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
async def get_materialized_span(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren: ...
class SQLiteTraceStore(TraceStore):
def __init__(self, conn_string: str):
self.conn_string = conn_string
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
def build_where_clause() -> tuple[str, list]:
if not attribute_filters:
return "", []
ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
conditions = [
f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?"
for condition in attribute_filters
]
params = [condition.value for condition in attribute_filters]
where_clause = " WHERE " + " AND ".join(conditions)
return where_clause, params
def build_order_clause() -> str:
if not order_by:
return ""
order_clauses = []
for field in order_by:
desc = field.startswith("-")
clean_field = field[1:] if desc else field
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
return " ORDER BY " + ", ".join(order_clauses)
# Build the main query
base_query = """
WITH matching_traces AS (
SELECT DISTINCT t.trace_id
FROM traces t
JOIN spans s ON t.trace_id = s.trace_id
{where_clause}
),
filtered_traces AS (
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
FROM matching_traces mt
JOIN traces t ON mt.trace_id = t.trace_id
LEFT JOIN spans s ON t.trace_id = s.trace_id
{order_clause}
)
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
FROM filtered_traces
LIMIT {limit} OFFSET {offset}
"""
where_clause, params = build_where_clause()
query = base_query.format(
where_clause=where_clause,
order_clause=build_order_clause(),
limit=limit,
offset=offset,
)
# Execute query and return results
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, params) as cursor:
rows = await cursor.fetchall()
return [
Trace(
trace_id=row["trace_id"],
root_span_id=row["root_span_id"],
start_time=datetime.fromisoformat(row["start_time"]),
end_time=datetime.fromisoformat(row["end_time"]),
)
for row in rows
]
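A concrete trace through build_where_clause shows why values never touch the SQL string: the op is first validated by the Literal type on QueryCondition, then mapped to a SQL operator via ops_map, while the value rides along as a bound parameter:

condition = QueryCondition(key="session_id", op="eq", value="abc123")
# build_where_clause() over [condition] yields:
#   where_clause == " WHERE json_extract(s.attributes, '$.session_id') = ?"
#   params       == ["abc123"]
# aiosqlite binds params positionally, so the value is never interpolated.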
async def get_materialized_span(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:
json_object = ", ".join(
f"'{key}', json_extract(s.attributes, '$.{key}')"
for key in attributes_to_return
)
attributes_select = f"json_object({json_object})"
# SQLite CTE query with filtered attributes
query = f"""
WITH RECURSIVE span_tree AS (
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
FROM spans s
WHERE s.span_id = ?
UNION ALL
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
FROM spans s
JOIN span_tree st ON s.parent_span_id = st.span_id
WHERE (? IS NULL OR st.depth < ?)
)
SELECT *
FROM span_tree
ORDER BY depth, start_time
"""
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
rows = await cursor.fetchall()
if not rows:
raise ValueError(f"Span {span_id} not found")
# Build span tree
spans_by_id = {}
root_span = None
for row in rows:
span = SpanWithChildren(
span_id=row["span_id"],
trace_id=row["trace_id"],
parent_span_id=row["parent_span_id"],
name=row["name"],
start_time=datetime.fromisoformat(row["start_time"]),
end_time=datetime.fromisoformat(row["end_time"]),
attributes=json.loads(row["filtered_attributes"]),
status=row["status"].lower(),
children=[],
)
spans_by_id[span.span_id] = span
if span.span_id == span_id:
root_span = span
elif span.parent_span_id in spans_by_id:
spans_by_id[span.parent_span_id].children.append(span)
return root_span
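The recursive CTE bounds the tree walk in SQL itself: the anchor row seeds the requested span at depth 1, and the recursive member keeps joining children while st.depth is below max_depth. A standalone sqlite3 sketch of the same traversal over a toy spans table (not the real schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE spans (span_id TEXT, parent_span_id TEXT, name TEXT);
    INSERT INTO spans VALUES ('a', NULL, 'turn'),
                             ('b', 'a', 'inference'),
                             ('c', 'b', 'tool_call');
    """
)
rows = conn.execute(
    """
    WITH RECURSIVE span_tree AS (
        SELECT s.*, 1 AS depth FROM spans s WHERE s.span_id = ?
        UNION ALL
        SELECT s.*, st.depth + 1 FROM spans s
        JOIN span_tree st ON s.parent_span_id = st.span_id
        WHERE (? IS NULL OR st.depth < ?)
    )
    SELECT span_id, depth FROM span_tree ORDER BY depth
    """,
    ("a", 2, 2),  # span_id, then max_depth twice, as in the query above
).fetchall()
print(rows)  # [('a', 1), ('b', 2)] -- 'c' is pruned by max_depth = 2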