mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
address feedback
This commit is contained in:
parent
32af1f9dd4
commit
b8c395c264
15 changed files with 672 additions and 151 deletions
|
@ -23,7 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
@ -428,7 +428,6 @@ class Agents(Protocol):
|
|||
) -> AgentCreateResponse: ...
|
||||
|
||||
@webmethod(route="/agents/turn/create")
|
||||
@traced(input="messages")
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
|
@ -38,7 +38,7 @@ class DatasetIO(Protocol):
|
|||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult: ...
|
||||
|
||||
@webmethod(route="/datasetio/upload", method="POST")
|
||||
async def upload_rows(
|
||||
@webmethod(route="/datasetio/append-rows", method="POST")
|
||||
async def append_rows(
|
||||
self, dataset_id: str, rows: List[Dict[str, Any]]
|
||||
) -> None: ...
|
||||
|
|
|
@ -21,7 +21,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
@ -227,7 +227,6 @@ class Inference(Protocol):
|
|||
model_store: ModelStore
|
||||
|
||||
@webmethod(route="/inference/completion")
|
||||
@traced(input="content")
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -239,7 +238,6 @@ class Inference(Protocol):
|
|||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(route="/inference/chat-completion")
|
||||
@traced(input="messages")
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -257,7 +255,6 @@ class Inference(Protocol):
|
|||
]: ...
|
||||
|
||||
@webmethod(route="/inference/embeddings")
|
||||
@traced(input="contents")
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -50,7 +50,6 @@ class Memory(Protocol):
|
|||
|
||||
# this will just block now until documents are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
@traced(input="documents")
|
||||
@webmethod(route="/memory/insert")
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -60,7 +59,6 @@ class Memory(Protocol):
|
|||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/query")
|
||||
@traced(input="query")
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
|
@ -50,7 +50,6 @@ class Safety(Protocol):
|
|||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run-shield")
|
||||
@traced(input="messages")
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
|
|
@ -21,6 +21,9 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
|
@ -147,57 +150,39 @@ class EvalTrace(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class MaterializedSpan(Span):
|
||||
children: List["MaterializedSpan"] = Field(default_factory=list)
|
||||
class SpanWithChildren(Span):
|
||||
children: List["SpanWithChildren"] = Field(default_factory=list)
|
||||
status: Optional[SpanStatus] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryCondition(BaseModel):
|
||||
key: str
|
||||
op: str
|
||||
op: Literal["eq", "ne", "gt", "lt"]
|
||||
value: Any
|
||||
|
||||
|
||||
class TraceStore(Protocol):
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]: ...
|
||||
|
||||
async def get_materialized_span(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> MaterializedSpan: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
|
||||
@webmethod(route="/telemetry/log-event")
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: ...
|
||||
async def log_event(
|
||||
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/telemetry/query-traces", method="GET")
|
||||
@webmethod(route="/telemetry/query-traces", method="POST")
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]: ...
|
||||
|
||||
@webmethod(route="/telemetry/get-materialized-span", method="GET")
|
||||
async def get_materialized_span(
|
||||
@webmethod(route="/telemetry/get-span-tree", method="POST")
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> MaterializedSpan: ...
|
||||
) -> SpanWithChildren: ...
|
||||
|
|
|
@ -222,8 +222,8 @@ class DatasetIORouter(DatasetIO):
|
|||
filter_condition=filter_condition,
|
||||
)
|
||||
|
||||
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
return await self.routing_table.get_provider_impl(dataset_id).upload_rows(
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
)
|
||||
|
|
|
@ -24,7 +24,7 @@ def serialize_value(value: Any) -> str:
|
|||
return value.model_dump_json()
|
||||
elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
|
||||
return json.dumps([item.model_dump_json() for item in value])
|
||||
elif hasattr(value, "to_dict"): # For objects with to_dict method
|
||||
elif hasattr(value, "to_dict"):
|
||||
return json.dumps(value.to_dict())
|
||||
elif isinstance(value, (dict, list, int, float, str, bool)):
|
||||
return json.dumps(value)
|
||||
|
@ -34,21 +34,6 @@ def serialize_value(value: Any) -> str:
|
|||
return str(value)
|
||||
|
||||
|
||||
def traced(input: str = None):
|
||||
"""
|
||||
A method decorator that enables tracing with input and output capture.
|
||||
|
||||
Args:
|
||||
input: Name of the input parameter to capture in traces
|
||||
"""
|
||||
|
||||
def decorator(method: Callable) -> Callable:
|
||||
method._trace_input = input
|
||||
return method
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def trace_protocol(cls: Type[T]) -> Type[T]:
|
||||
"""
|
||||
A class decorator that automatically traces all methods in a protocol/base class
|
||||
|
@ -59,22 +44,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
is_async = asyncio.iscoroutinefunction(method)
|
||||
is_async_gen = inspect.isasyncgenfunction(method)
|
||||
|
||||
def get_traced_input(args: tuple, kwargs: dict) -> dict:
|
||||
trace_input = getattr(method, "_trace_input", None)
|
||||
if not trace_input:
|
||||
return {}
|
||||
|
||||
# Get the mapping of parameter names to values
|
||||
sig = inspect.signature(method)
|
||||
bound_args = sig.bind(None, *args, **kwargs) # None for self
|
||||
bound_args.apply_defaults()
|
||||
params = dict(list(bound_args.arguments.items())[1:]) # Skip 'self'
|
||||
|
||||
# Return the input value if the key exists
|
||||
if trace_input in params:
|
||||
return {"input": serialize_value(params[trace_input])}
|
||||
return {}
|
||||
|
||||
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
|
||||
class_name = self.__class__.__name__
|
||||
method_name = method.__name__
|
||||
|
@ -87,7 +56,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
"method": method_name,
|
||||
"type": span_type,
|
||||
"args": serialize_value(args),
|
||||
**get_traced_input(args, kwargs),
|
||||
}
|
||||
|
||||
return class_name, method_name, span_attributes
|
||||
|
@ -145,33 +113,16 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
else:
|
||||
return sync_wrapper
|
||||
|
||||
# Store the original __init_subclass__ if it exists
|
||||
original_init_subclass = getattr(cls, "__init_subclass__", None)
|
||||
|
||||
# Define a new __init_subclass__ to handle child classes
|
||||
def __init_subclass__(cls_child, **kwargs): # noqa: N807
|
||||
# Call original __init_subclass__ if it exists
|
||||
if original_init_subclass:
|
||||
original_init_subclass(**kwargs)
|
||||
|
||||
traced_methods = {}
|
||||
for parent in cls_child.__mro__[1:]: # Skip the class itself
|
||||
for name, method in vars(parent).items():
|
||||
if inspect.isfunction(method) and getattr(
|
||||
method, "_trace_input", None
|
||||
): # noqa: B009
|
||||
traced_methods[name] = getattr(method, "_trace_input") # noqa: B009
|
||||
|
||||
# Trace child class methods if their name matches a traced parent method
|
||||
for name, method in vars(cls_child).items():
|
||||
if inspect.isfunction(method) and not name.startswith("_"):
|
||||
if name in traced_methods:
|
||||
# Copy the trace configuration from the parent
|
||||
setattr(method, "_trace_input", traced_methods[name]) # noqa: B010
|
||||
|
||||
setattr(cls_child, name, trace_method(method)) # noqa: B010
|
||||
|
||||
# Set the new __init_subclass__
|
||||
cls.__init_subclass__ = classmethod(__init_subclass__)
|
||||
|
||||
return cls
|
||||
|
|
|
@ -132,7 +132,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
next_page_token=str(end),
|
||||
)
|
||||
|
||||
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
if dataset_info is None:
|
||||
raise ValueError(f"Dataset with id {dataset_id} not found")
|
||||
|
|
135
llama_stack/providers/inline/meta_reference/telemetry/console.py
Normal file
135
llama_stack/providers/inline/meta_reference/telemetry/console.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from .config import LogFormat
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
from .config import ConsoleConfig
|
||||
|
||||
|
||||
class ConsoleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: ConsoleConfig) -> None:
|
||||
self.config = config
|
||||
self.spans = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def log_event(self, event: Event):
|
||||
if (
|
||||
isinstance(event, StructuredLogEvent)
|
||||
and event.payload.type == StructuredLogType.SPAN_START.value
|
||||
):
|
||||
self.spans[event.span_id] = event.payload
|
||||
|
||||
names = []
|
||||
span_id = event.span_id
|
||||
while True:
|
||||
span_payload = self.spans.get(span_id)
|
||||
if not span_payload:
|
||||
break
|
||||
|
||||
names = [span_payload.name] + names
|
||||
span_id = span_payload.parent_span_id
|
||||
|
||||
span_name = ".".join(names) if names else None
|
||||
|
||||
if self.config.log_format == LogFormat.JSON:
|
||||
formatted = format_event_json(event, span_name)
|
||||
else:
|
||||
formatted = format_event_text(event, span_name)
|
||||
|
||||
if formatted:
|
||||
print(formatted)
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
raise NotImplementedError("Console telemetry does not support trace querying")
|
||||
|
||||
async def get_spans(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> SpanWithChildren:
|
||||
raise NotImplementedError("Console telemetry does not support span querying")
|
||||
|
||||
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
|
||||
SEVERITY_COLORS = {
|
||||
LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
|
||||
LogSeverity.DEBUG: COLORS["cyan"],
|
||||
LogSeverity.INFO: COLORS["green"],
|
||||
LogSeverity.WARN: COLORS["yellow"],
|
||||
LogSeverity.ERROR: COLORS["red"],
|
||||
LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
|
||||
}
|
||||
|
||||
|
||||
def format_event_text(event: Event, span_name: str) -> Optional[str]:
|
||||
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
|
||||
span = ""
|
||||
if span_name:
|
||||
span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
|
||||
return (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
|
||||
f"{span}"
|
||||
f"{event.message}"
|
||||
)
|
||||
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
return None
|
||||
|
||||
return f"Unknown event type: {event}"
|
||||
|
||||
|
||||
def format_event_json(event: Event, span_name: str) -> Optional[str]:
|
||||
base_data = {
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"trace_id": event.trace_id,
|
||||
"span_id": event.span_id,
|
||||
"span_name": span_name,
|
||||
}
|
||||
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
base_data.update(
|
||||
{"type": "log", "severity": event.severity.name, "message": event.message}
|
||||
)
|
||||
return json.dumps(base_data)
|
||||
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
return None
|
||||
|
||||
return json.dumps({"error": f"Unknown event type: {event}"})
|
|
@ -24,7 +24,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.console_span_processo
|
|||
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
|
||||
SQLiteSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.sqlite import SQLiteTraceStore
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
|
@ -222,28 +222,26 @@ class TelemetryAdapter(Telemetry):
|
|||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
return await self.trace_store.query_traces(
|
||||
attribute_conditions=attribute_conditions,
|
||||
attribute_keys_to_return=attribute_keys_to_return,
|
||||
attribute_filters=attribute_filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
async def get_materialized_span(
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> MaterializedSpan:
|
||||
) -> SpanWithChildren:
|
||||
return await self.trace_store.get_materialized_span(
|
||||
span_id=span_id,
|
||||
attribute_keys_to_return=attribute_keys_to_return,
|
||||
attributes_to_return=attributes_to_return,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
|
|
@ -96,7 +96,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
next_page_token=str(end),
|
||||
)
|
||||
|
||||
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
|
||||
|
|
|
@ -0,0 +1,259 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.remote.telemetry.opentelemetry.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.remote.telemetry.opentelemetry.sqlite_span_processor import (
|
||||
SQLiteSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from .config import OpenTelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
def string_to_trace_id(s: str) -> int:
|
||||
# Convert the string to bytes and then to an integer
|
||||
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||
|
||||
|
||||
def string_to_span_id(s: str) -> int:
|
||||
# Use only the first 8 bytes (64 bits) for span ID
|
||||
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class OpenTelemetryAdapter(Telemetry):
|
||||
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.datasetio = deps[Api.datasetio]
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
}
|
||||
)
|
||||
|
||||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
if TelemetrySink.JAEGER in self.config.sinks:
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(
|
||||
resource=resource, metric_readers=[metric_reader]
|
||||
)
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(
|
||||
SQLiteSpanProcessor(self.config.sqlite_db_path)
|
||||
)
|
||||
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
trace.get_tracer_provider().shutdown()
|
||||
metrics.get_meter_provider().shutdown()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**event.attributes,
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
|
||||
)
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["counters"][name]
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(
|
||||
event.metric, event.unit
|
||||
)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(
|
||||
self, name: str, unit: str
|
||||
) -> metrics.UpDownCounter:
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = (
|
||||
self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
trace_id = string_to_trace_id(event.trace_id)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
parent_span = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
|
||||
context = trace.Context(trace_id=trace_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span, context)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
if span:
|
||||
if event.attributes:
|
||||
span.set_attributes(event.attributes)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
span.set_status(status)
|
||||
span.end()
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
return await self.trace_store.query_traces(
|
||||
attribute_conditions=attribute_conditions,
|
||||
attribute_keys_to_return=attribute_keys_to_return,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
async def get_spans(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> SpanWithChildren:
|
||||
return await self.trace_store.get_spans(
|
||||
span_id=span_id,
|
||||
attribute_conditions=attribute_conditions,
|
||||
attribute_keys_to_return=attribute_keys_to_return,
|
||||
max_depth=max_depth,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
|
@ -11,8 +11,8 @@ from typing import List, Optional
|
|||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
MaterializedSpan,
|
||||
QueryCondition,
|
||||
SpanWithChildren,
|
||||
Trace,
|
||||
TraceStore,
|
||||
)
|
||||
|
@ -24,56 +24,76 @@ class SQLiteTraceStore(TraceStore):
|
|||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
# Build the SQL query with attribute selection
|
||||
select_clause = """
|
||||
SELECT DISTINCT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||
"""
|
||||
if attribute_keys_to_return:
|
||||
for key in attribute_keys_to_return:
|
||||
select_clause += (
|
||||
f", json_extract(s.attributes, '$.{key}') as attr_{key}"
|
||||
)
|
||||
print(attribute_filters, attributes_to_return, limit, offset, order_by)
|
||||
|
||||
query = (
|
||||
select_clause
|
||||
+ """
|
||||
FROM traces t
|
||||
JOIN spans s ON t.trace_id = s.trace_id
|
||||
"""
|
||||
)
|
||||
params = []
|
||||
def build_attribute_select() -> str:
|
||||
if not attributes_to_return:
|
||||
return ""
|
||||
return "".join(
|
||||
f", json_extract(s.attributes, '$.{key}') as attr_{key}"
|
||||
for key in attributes_to_return
|
||||
)
|
||||
|
||||
# Add attribute conditions if present
|
||||
if attribute_conditions:
|
||||
conditions = []
|
||||
for condition in attribute_conditions:
|
||||
conditions.append(
|
||||
f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?"
|
||||
)
|
||||
params.append(condition.value)
|
||||
if conditions:
|
||||
query += " WHERE " + " AND ".join(conditions)
|
||||
def build_where_clause() -> tuple[str, list]:
|
||||
if not attribute_filters:
|
||||
return "", []
|
||||
|
||||
conditions = [
|
||||
f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?"
|
||||
for condition in attribute_filters
|
||||
]
|
||||
params = [condition.value for condition in attribute_filters]
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
return where_clause, params
|
||||
|
||||
def build_order_clause() -> str:
|
||||
if not order_by:
|
||||
return ""
|
||||
|
||||
# Add ordering
|
||||
if order_by:
|
||||
order_clauses = []
|
||||
for field in order_by:
|
||||
desc = False
|
||||
if field.startswith("-"):
|
||||
field = field[1:]
|
||||
desc = True
|
||||
order_clauses.append(f"t.{field} {'DESC' if desc else 'ASC'}")
|
||||
query += " ORDER BY " + ", ".join(order_clauses)
|
||||
desc = field.startswith("-")
|
||||
clean_field = field[1:] if desc else field
|
||||
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
|
||||
return " ORDER BY " + ", ".join(order_clauses)
|
||||
|
||||
# Add limit and offset
|
||||
query += f" LIMIT {limit} OFFSET {offset}"
|
||||
# Build the main query
|
||||
base_query = """
|
||||
WITH matching_traces AS (
|
||||
SELECT DISTINCT t.trace_id
|
||||
FROM traces t
|
||||
JOIN spans s ON t.trace_id = s.trace_id
|
||||
{where_clause}
|
||||
),
|
||||
filtered_traces AS (
|
||||
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||
{attribute_select}
|
||||
FROM matching_traces mt
|
||||
JOIN traces t ON mt.trace_id = t.trace_id
|
||||
LEFT JOIN spans s ON t.trace_id = s.trace_id
|
||||
{order_clause}
|
||||
)
|
||||
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||
FROM filtered_traces
|
||||
LIMIT {limit} OFFSET {offset}
|
||||
"""
|
||||
|
||||
where_clause, params = build_where_clause()
|
||||
query = base_query.format(
|
||||
attribute_select=build_attribute_select(),
|
||||
where_clause=where_clause,
|
||||
order_clause=build_order_clause(),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Execute query and return results
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, params) as cursor:
|
||||
|
@ -91,15 +111,15 @@ class SQLiteTraceStore(TraceStore):
|
|||
async def get_materialized_span(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> MaterializedSpan:
|
||||
) -> SpanWithChildren:
|
||||
# Build the attributes selection
|
||||
attributes_select = "s.attributes"
|
||||
if attribute_keys_to_return:
|
||||
if attributes_to_return:
|
||||
json_object = ", ".join(
|
||||
f"'{key}', json_extract(s.attributes, '$.{key}')"
|
||||
for key in attribute_keys_to_return
|
||||
for key in attributes_to_return
|
||||
)
|
||||
attributes_select = f"json_object({json_object})"
|
||||
|
||||
|
@ -135,7 +155,7 @@ class SQLiteTraceStore(TraceStore):
|
|||
root_span = None
|
||||
|
||||
for row in rows:
|
||||
span = MaterializedSpan(
|
||||
span = SpanWithChildren(
|
||||
span_id=row["span_id"],
|
||||
trace_id=row["trace_id"],
|
||||
parent_span_id=row["parent_span_id"],
|
||||
|
|
180
llama_stack/providers/utils/telemetry/sqlite_trace_store.py
Normal file
180
llama_stack/providers/utils/telemetry/sqlite_trace_store.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
|
||||
|
||||
|
||||
class TraceStore(Protocol):
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]: ...
|
||||
|
||||
async def get_materialized_span(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren: ...
|
||||
|
||||
|
||||
class SQLiteTraceStore(TraceStore):
|
||||
def __init__(self, conn_string: str):
|
||||
self.conn_string = conn_string
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
|
||||
def build_where_clause() -> tuple[str, list]:
|
||||
if not attribute_filters:
|
||||
return "", []
|
||||
|
||||
ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
|
||||
|
||||
conditions = [
|
||||
f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?"
|
||||
for condition in attribute_filters
|
||||
]
|
||||
params = [condition.value for condition in attribute_filters]
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
return where_clause, params
|
||||
|
||||
def build_order_clause() -> str:
|
||||
if not order_by:
|
||||
return ""
|
||||
|
||||
order_clauses = []
|
||||
for field in order_by:
|
||||
desc = field.startswith("-")
|
||||
clean_field = field[1:] if desc else field
|
||||
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
|
||||
return " ORDER BY " + ", ".join(order_clauses)
|
||||
|
||||
# Build the main query
|
||||
base_query = """
|
||||
WITH matching_traces AS (
|
||||
SELECT DISTINCT t.trace_id
|
||||
FROM traces t
|
||||
JOIN spans s ON t.trace_id = s.trace_id
|
||||
{where_clause}
|
||||
),
|
||||
filtered_traces AS (
|
||||
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||
FROM matching_traces mt
|
||||
JOIN traces t ON mt.trace_id = t.trace_id
|
||||
LEFT JOIN spans s ON t.trace_id = s.trace_id
|
||||
{order_clause}
|
||||
)
|
||||
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||
FROM filtered_traces
|
||||
LIMIT {limit} OFFSET {offset}
|
||||
"""
|
||||
|
||||
where_clause, params = build_where_clause()
|
||||
query = base_query.format(
|
||||
where_clause=where_clause,
|
||||
order_clause=build_order_clause(),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Execute query and return results
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, params) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
Trace(
|
||||
trace_id=row["trace_id"],
|
||||
root_span_id=row["root_span_id"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_materialized_span(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren:
|
||||
# Build the attributes selection
|
||||
attributes_select = "s.attributes"
|
||||
if attributes_to_return:
|
||||
json_object = ", ".join(
|
||||
f"'{key}', json_extract(s.attributes, '$.{key}')"
|
||||
for key in attributes_to_return
|
||||
)
|
||||
attributes_select = f"json_object({json_object})"
|
||||
|
||||
# SQLite CTE query with filtered attributes
|
||||
query = f"""
|
||||
WITH RECURSIVE span_tree AS (
|
||||
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
WHERE s.span_id = ?
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
JOIN span_tree st ON s.parent_span_id = st.span_id
|
||||
WHERE (? IS NULL OR st.depth < ?)
|
||||
)
|
||||
SELECT *
|
||||
FROM span_tree
|
||||
ORDER BY depth, start_time
|
||||
"""
|
||||
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
raise ValueError(f"Span {span_id} not found")
|
||||
|
||||
# Build span tree
|
||||
spans_by_id = {}
|
||||
root_span = None
|
||||
|
||||
for row in rows:
|
||||
span = SpanWithChildren(
|
||||
span_id=row["span_id"],
|
||||
trace_id=row["trace_id"],
|
||||
parent_span_id=row["parent_span_id"],
|
||||
name=row["name"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
attributes=json.loads(row["filtered_attributes"]),
|
||||
status=row["status"].lower(),
|
||||
children=[],
|
||||
)
|
||||
|
||||
spans_by_id[span.span_id] = span
|
||||
|
||||
if span.span_id == span_id:
|
||||
root_span = span
|
||||
elif span.parent_span_id in spans_by_id:
|
||||
spans_by_id[span.parent_span_id].children.append(span)
|
||||
|
||||
return root_span
|
Loading…
Add table
Add a link
Reference in a new issue