mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
address feedback
This commit is contained in:
parent
32af1f9dd4
commit
b8c395c264
15 changed files with 672 additions and 151 deletions
|
@ -23,7 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
from llama_stack.distribution.tracing import trace_protocol
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
@ -428,7 +428,6 @@ class Agents(Protocol):
|
||||||
) -> AgentCreateResponse: ...
|
) -> AgentCreateResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/turn/create")
|
@webmethod(route="/agents/turn/create")
|
||||||
@traced(input="messages")
|
|
||||||
async def create_agent_turn(
|
async def create_agent_turn(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
|
|
@ -38,7 +38,7 @@ class DatasetIO(Protocol):
|
||||||
filter_condition: Optional[str] = None,
|
filter_condition: Optional[str] = None,
|
||||||
) -> PaginatedRowsResult: ...
|
) -> PaginatedRowsResult: ...
|
||||||
|
|
||||||
@webmethod(route="/datasetio/upload", method="POST")
|
@webmethod(route="/datasetio/append-rows", method="POST")
|
||||||
async def upload_rows(
|
async def append_rows(
|
||||||
self, dataset_id: str, rows: List[Dict[str, Any]]
|
self, dataset_id: str, rows: List[Dict[str, Any]]
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
|
@ -21,7 +21,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
from llama_stack.distribution.tracing import trace_protocol
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.models import * # noqa: F403
|
from llama_stack.apis.models import * # noqa: F403
|
||||||
|
@ -227,7 +227,6 @@ class Inference(Protocol):
|
||||||
model_store: ModelStore
|
model_store: ModelStore
|
||||||
|
|
||||||
@webmethod(route="/inference/completion")
|
@webmethod(route="/inference/completion")
|
||||||
@traced(input="content")
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -239,7 +238,6 @@ class Inference(Protocol):
|
||||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
|
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(route="/inference/chat-completion")
|
@webmethod(route="/inference/chat-completion")
|
||||||
@traced(input="messages")
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -257,7 +255,6 @@ class Inference(Protocol):
|
||||||
]: ...
|
]: ...
|
||||||
|
|
||||||
@webmethod(route="/inference/embeddings")
|
@webmethod(route="/inference/embeddings")
|
||||||
@traced(input="contents")
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
from llama_stack.distribution.tracing import trace_protocol
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -50,7 +50,6 @@ class Memory(Protocol):
|
||||||
|
|
||||||
# this will just block now until documents are inserted, but it should
|
# this will just block now until documents are inserted, but it should
|
||||||
# probably return a Job instance which can be polled for completion
|
# probably return a Job instance which can be polled for completion
|
||||||
@traced(input="documents")
|
|
||||||
@webmethod(route="/memory/insert")
|
@webmethod(route="/memory/insert")
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -60,7 +59,6 @@ class Memory(Protocol):
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/memory/query")
|
@webmethod(route="/memory/query")
|
||||||
@traced(input="query")
|
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -10,7 +10,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.distribution.tracing import trace_protocol, traced
|
from llama_stack.distribution.tracing import trace_protocol
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.shields import * # noqa: F403
|
from llama_stack.apis.shields import * # noqa: F403
|
||||||
|
@ -50,7 +50,6 @@ class Safety(Protocol):
|
||||||
shield_store: ShieldStore
|
shield_store: ShieldStore
|
||||||
|
|
||||||
@webmethod(route="/safety/run-shield")
|
@webmethod(route="/safety/run-shield")
|
||||||
@traced(input="messages")
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -21,6 +21,9 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
# Add this constant near the top of the file, after the imports
|
||||||
|
DEFAULT_TTL_DAYS = 7
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SpanStatus(Enum):
|
class SpanStatus(Enum):
|
||||||
|
@ -147,57 +150,39 @@ class EvalTrace(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MaterializedSpan(Span):
|
class SpanWithChildren(Span):
|
||||||
children: List["MaterializedSpan"] = Field(default_factory=list)
|
children: List["SpanWithChildren"] = Field(default_factory=list)
|
||||||
status: Optional[SpanStatus] = None
|
status: Optional[SpanStatus] = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class QueryCondition(BaseModel):
|
class QueryCondition(BaseModel):
|
||||||
key: str
|
key: str
|
||||||
op: str
|
op: Literal["eq", "ne", "gt", "lt"]
|
||||||
value: Any
|
value: Any
|
||||||
|
|
||||||
|
|
||||||
class TraceStore(Protocol):
|
|
||||||
|
|
||||||
async def query_traces(
|
|
||||||
self,
|
|
||||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
|
||||||
limit: Optional[int] = 100,
|
|
||||||
offset: Optional[int] = 0,
|
|
||||||
order_by: Optional[List[str]] = None,
|
|
||||||
) -> List[Trace]: ...
|
|
||||||
|
|
||||||
async def get_materialized_span(
|
|
||||||
self,
|
|
||||||
span_id: str,
|
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
|
||||||
max_depth: Optional[int] = None,
|
|
||||||
) -> MaterializedSpan: ...
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class Telemetry(Protocol):
|
class Telemetry(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/telemetry/log-event")
|
@webmethod(route="/telemetry/log-event")
|
||||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: ...
|
async def log_event(
|
||||||
|
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/query-traces", method="GET")
|
@webmethod(route="/telemetry/query-traces", method="POST")
|
||||||
async def query_traces(
|
async def query_traces(
|
||||||
self,
|
self,
|
||||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
|
||||||
limit: Optional[int] = 100,
|
limit: Optional[int] = 100,
|
||||||
offset: Optional[int] = 0,
|
offset: Optional[int] = 0,
|
||||||
order_by: Optional[List[str]] = None,
|
order_by: Optional[List[str]] = None,
|
||||||
) -> List[Trace]: ...
|
) -> List[Trace]: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/get-materialized-span", method="GET")
|
@webmethod(route="/telemetry/get-span-tree", method="POST")
|
||||||
async def get_materialized_span(
|
async def get_span_tree(
|
||||||
self,
|
self,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
attributes_to_return: Optional[List[str]] = None,
|
||||||
max_depth: Optional[int] = None,
|
max_depth: Optional[int] = None,
|
||||||
) -> MaterializedSpan: ...
|
) -> SpanWithChildren: ...
|
||||||
|
|
|
@ -222,8 +222,8 @@ class DatasetIORouter(DatasetIO):
|
||||||
filter_condition=filter_condition,
|
filter_condition=filter_condition,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).upload_rows(
|
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows=rows,
|
rows=rows,
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,7 +24,7 @@ def serialize_value(value: Any) -> str:
|
||||||
return value.model_dump_json()
|
return value.model_dump_json()
|
||||||
elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
|
elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
|
||||||
return json.dumps([item.model_dump_json() for item in value])
|
return json.dumps([item.model_dump_json() for item in value])
|
||||||
elif hasattr(value, "to_dict"): # For objects with to_dict method
|
elif hasattr(value, "to_dict"):
|
||||||
return json.dumps(value.to_dict())
|
return json.dumps(value.to_dict())
|
||||||
elif isinstance(value, (dict, list, int, float, str, bool)):
|
elif isinstance(value, (dict, list, int, float, str, bool)):
|
||||||
return json.dumps(value)
|
return json.dumps(value)
|
||||||
|
@ -34,21 +34,6 @@ def serialize_value(value: Any) -> str:
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
def traced(input: str = None):
|
|
||||||
"""
|
|
||||||
A method decorator that enables tracing with input and output capture.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input: Name of the input parameter to capture in traces
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(method: Callable) -> Callable:
|
|
||||||
method._trace_input = input
|
|
||||||
return method
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def trace_protocol(cls: Type[T]) -> Type[T]:
|
def trace_protocol(cls: Type[T]) -> Type[T]:
|
||||||
"""
|
"""
|
||||||
A class decorator that automatically traces all methods in a protocol/base class
|
A class decorator that automatically traces all methods in a protocol/base class
|
||||||
|
@ -59,22 +44,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
||||||
is_async = asyncio.iscoroutinefunction(method)
|
is_async = asyncio.iscoroutinefunction(method)
|
||||||
is_async_gen = inspect.isasyncgenfunction(method)
|
is_async_gen = inspect.isasyncgenfunction(method)
|
||||||
|
|
||||||
def get_traced_input(args: tuple, kwargs: dict) -> dict:
|
|
||||||
trace_input = getattr(method, "_trace_input", None)
|
|
||||||
if not trace_input:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Get the mapping of parameter names to values
|
|
||||||
sig = inspect.signature(method)
|
|
||||||
bound_args = sig.bind(None, *args, **kwargs) # None for self
|
|
||||||
bound_args.apply_defaults()
|
|
||||||
params = dict(list(bound_args.arguments.items())[1:]) # Skip 'self'
|
|
||||||
|
|
||||||
# Return the input value if the key exists
|
|
||||||
if trace_input in params:
|
|
||||||
return {"input": serialize_value(params[trace_input])}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
|
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
|
||||||
class_name = self.__class__.__name__
|
class_name = self.__class__.__name__
|
||||||
method_name = method.__name__
|
method_name = method.__name__
|
||||||
|
@ -87,7 +56,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
||||||
"method": method_name,
|
"method": method_name,
|
||||||
"type": span_type,
|
"type": span_type,
|
||||||
"args": serialize_value(args),
|
"args": serialize_value(args),
|
||||||
**get_traced_input(args, kwargs),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return class_name, method_name, span_attributes
|
return class_name, method_name, span_attributes
|
||||||
|
@ -145,33 +113,16 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
||||||
else:
|
else:
|
||||||
return sync_wrapper
|
return sync_wrapper
|
||||||
|
|
||||||
# Store the original __init_subclass__ if it exists
|
|
||||||
original_init_subclass = getattr(cls, "__init_subclass__", None)
|
original_init_subclass = getattr(cls, "__init_subclass__", None)
|
||||||
|
|
||||||
# Define a new __init_subclass__ to handle child classes
|
|
||||||
def __init_subclass__(cls_child, **kwargs): # noqa: N807
|
def __init_subclass__(cls_child, **kwargs): # noqa: N807
|
||||||
# Call original __init_subclass__ if it exists
|
|
||||||
if original_init_subclass:
|
if original_init_subclass:
|
||||||
original_init_subclass(**kwargs)
|
original_init_subclass(**kwargs)
|
||||||
|
|
||||||
traced_methods = {}
|
|
||||||
for parent in cls_child.__mro__[1:]: # Skip the class itself
|
|
||||||
for name, method in vars(parent).items():
|
|
||||||
if inspect.isfunction(method) and getattr(
|
|
||||||
method, "_trace_input", None
|
|
||||||
): # noqa: B009
|
|
||||||
traced_methods[name] = getattr(method, "_trace_input") # noqa: B009
|
|
||||||
|
|
||||||
# Trace child class methods if their name matches a traced parent method
|
|
||||||
for name, method in vars(cls_child).items():
|
for name, method in vars(cls_child).items():
|
||||||
if inspect.isfunction(method) and not name.startswith("_"):
|
if inspect.isfunction(method) and not name.startswith("_"):
|
||||||
if name in traced_methods:
|
|
||||||
# Copy the trace configuration from the parent
|
|
||||||
setattr(method, "_trace_input", traced_methods[name]) # noqa: B010
|
|
||||||
|
|
||||||
setattr(cls_child, name, trace_method(method)) # noqa: B010
|
setattr(cls_child, name, trace_method(method)) # noqa: B010
|
||||||
|
|
||||||
# Set the new __init_subclass__
|
|
||||||
cls.__init_subclass__ = classmethod(__init_subclass__)
|
cls.__init_subclass__ = classmethod(__init_subclass__)
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
|
@ -132,7 +132,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
next_page_token=str(end),
|
next_page_token=str(end),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
dataset_info = self.dataset_infos.get(dataset_id)
|
dataset_info = self.dataset_infos.get(dataset_id)
|
||||||
if dataset_info is None:
|
if dataset_info is None:
|
||||||
raise ValueError(f"Dataset with id {dataset_id} not found")
|
raise ValueError(f"Dataset with id {dataset_id} not found")
|
||||||
|
|
135
llama_stack/providers/inline/meta_reference/telemetry/console.py
Normal file
135
llama_stack/providers/inline/meta_reference/telemetry/console.py
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from .config import LogFormat
|
||||||
|
|
||||||
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
from .config import ConsoleConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ConsoleTelemetryImpl(Telemetry):
|
||||||
|
def __init__(self, config: ConsoleConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.spans = {}
|
||||||
|
|
||||||
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
|
async def log_event(self, event: Event):
|
||||||
|
if (
|
||||||
|
isinstance(event, StructuredLogEvent)
|
||||||
|
and event.payload.type == StructuredLogType.SPAN_START.value
|
||||||
|
):
|
||||||
|
self.spans[event.span_id] = event.payload
|
||||||
|
|
||||||
|
names = []
|
||||||
|
span_id = event.span_id
|
||||||
|
while True:
|
||||||
|
span_payload = self.spans.get(span_id)
|
||||||
|
if not span_payload:
|
||||||
|
break
|
||||||
|
|
||||||
|
names = [span_payload.name] + names
|
||||||
|
span_id = span_payload.parent_span_id
|
||||||
|
|
||||||
|
span_name = ".".join(names) if names else None
|
||||||
|
|
||||||
|
if self.config.log_format == LogFormat.JSON:
|
||||||
|
formatted = format_event_json(event, span_name)
|
||||||
|
else:
|
||||||
|
formatted = format_event_text(event, span_name)
|
||||||
|
|
||||||
|
if formatted:
|
||||||
|
print(formatted)
|
||||||
|
|
||||||
|
async def query_traces(
|
||||||
|
self,
|
||||||
|
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||||
|
attribute_keys_to_return: Optional[List[str]] = None,
|
||||||
|
limit: Optional[int] = 100,
|
||||||
|
offset: Optional[int] = 0,
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
) -> List[Trace]:
|
||||||
|
raise NotImplementedError("Console telemetry does not support trace querying")
|
||||||
|
|
||||||
|
async def get_spans(
|
||||||
|
self,
|
||||||
|
span_id: str,
|
||||||
|
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||||
|
attribute_keys_to_return: Optional[List[str]] = None,
|
||||||
|
max_depth: Optional[int] = None,
|
||||||
|
limit: Optional[int] = 100,
|
||||||
|
offset: Optional[int] = 0,
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
) -> SpanWithChildren:
|
||||||
|
raise NotImplementedError("Console telemetry does not support span querying")
|
||||||
|
|
||||||
|
|
||||||
|
COLORS = {
|
||||||
|
"reset": "\033[0m",
|
||||||
|
"bold": "\033[1m",
|
||||||
|
"dim": "\033[2m",
|
||||||
|
"red": "\033[31m",
|
||||||
|
"green": "\033[32m",
|
||||||
|
"yellow": "\033[33m",
|
||||||
|
"blue": "\033[34m",
|
||||||
|
"magenta": "\033[35m",
|
||||||
|
"cyan": "\033[36m",
|
||||||
|
"white": "\033[37m",
|
||||||
|
}
|
||||||
|
|
||||||
|
SEVERITY_COLORS = {
|
||||||
|
LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
|
||||||
|
LogSeverity.DEBUG: COLORS["cyan"],
|
||||||
|
LogSeverity.INFO: COLORS["green"],
|
||||||
|
LogSeverity.WARN: COLORS["yellow"],
|
||||||
|
LogSeverity.ERROR: COLORS["red"],
|
||||||
|
LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def format_event_text(event: Event, span_name: str) -> Optional[str]:
|
||||||
|
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
|
||||||
|
span = ""
|
||||||
|
if span_name:
|
||||||
|
span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
|
||||||
|
if isinstance(event, UnstructuredLogEvent):
|
||||||
|
severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
|
||||||
|
return (
|
||||||
|
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||||
|
f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
|
||||||
|
f"{span}"
|
||||||
|
f"{event.message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(event, StructuredLogEvent):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return f"Unknown event type: {event}"
|
||||||
|
|
||||||
|
|
||||||
|
def format_event_json(event: Event, span_name: str) -> Optional[str]:
|
||||||
|
base_data = {
|
||||||
|
"timestamp": event.timestamp.isoformat(),
|
||||||
|
"trace_id": event.trace_id,
|
||||||
|
"span_id": event.span_id,
|
||||||
|
"span_name": span_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(event, UnstructuredLogEvent):
|
||||||
|
base_data.update(
|
||||||
|
{"type": "log", "severity": event.severity.name, "message": event.message}
|
||||||
|
)
|
||||||
|
return json.dumps(base_data)
|
||||||
|
|
||||||
|
elif isinstance(event, StructuredLogEvent):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return json.dumps({"error": f"Unknown event type: {event}"})
|
|
@ -24,7 +24,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.console_span_processo
|
||||||
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
|
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
|
||||||
SQLiteSpanProcessor,
|
SQLiteSpanProcessor,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.telemetry.sqlite import SQLiteTraceStore
|
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
|
||||||
|
@ -222,28 +222,26 @@ class TelemetryAdapter(Telemetry):
|
||||||
|
|
||||||
async def query_traces(
|
async def query_traces(
|
||||||
self,
|
self,
|
||||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
|
||||||
limit: Optional[int] = 100,
|
limit: Optional[int] = 100,
|
||||||
offset: Optional[int] = 0,
|
offset: Optional[int] = 0,
|
||||||
order_by: Optional[List[str]] = None,
|
order_by: Optional[List[str]] = None,
|
||||||
) -> List[Trace]:
|
) -> List[Trace]:
|
||||||
return await self.trace_store.query_traces(
|
return await self.trace_store.query_traces(
|
||||||
attribute_conditions=attribute_conditions,
|
attribute_filters=attribute_filters,
|
||||||
attribute_keys_to_return=attribute_keys_to_return,
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
order_by=order_by,
|
order_by=order_by,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_materialized_span(
|
async def get_span_tree(
|
||||||
self,
|
self,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
attributes_to_return: Optional[List[str]] = None,
|
||||||
max_depth: Optional[int] = None,
|
max_depth: Optional[int] = None,
|
||||||
) -> MaterializedSpan:
|
) -> SpanWithChildren:
|
||||||
return await self.trace_store.get_materialized_span(
|
return await self.trace_store.get_materialized_span(
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
attribute_keys_to_return=attribute_keys_to_return,
|
attributes_to_return=attributes_to_return,
|
||||||
max_depth=max_depth,
|
max_depth=max_depth,
|
||||||
)
|
)
|
||||||
|
|
|
@ -96,7 +96,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
next_page_token=str(end),
|
next_page_token=str(end),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
dataset_def = self.dataset_infos[dataset_id]
|
dataset_def = self.dataset_infos[dataset_id]
|
||||||
loaded_dataset = load_hf_dataset(dataset_def)
|
loaded_dataset = load_hf_dataset(dataset_def)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,259 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
from llama_stack.providers.remote.telemetry.opentelemetry.console_span_processor import (
|
||||||
|
ConsoleSpanProcessor,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.remote.telemetry.opentelemetry.sqlite_span_processor import (
|
||||||
|
SQLiteSpanProcessor,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||||
|
|
||||||
|
from opentelemetry import metrics, trace
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||||
|
from opentelemetry.sdk.metrics import MeterProvider
|
||||||
|
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||||
|
from opentelemetry.sdk.resources import Resource
|
||||||
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
|
from opentelemetry.semconv.resource import ResourceAttributes
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
|
||||||
|
from .config import OpenTelemetryConfig, TelemetrySink
|
||||||
|
|
||||||
|
_GLOBAL_STORAGE = {
|
||||||
|
"active_spans": {},
|
||||||
|
"counters": {},
|
||||||
|
"gauges": {},
|
||||||
|
"up_down_counters": {},
|
||||||
|
}
|
||||||
|
_global_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_trace_id(s: str) -> int:
|
||||||
|
# Convert the string to bytes and then to an integer
|
||||||
|
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_span_id(s: str) -> int:
|
||||||
|
# Use only the first 8 bytes (64 bits) for span ID
|
||||||
|
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def is_tracing_enabled(tracer):
|
||||||
|
with tracer.start_as_current_span("check_tracing") as span:
|
||||||
|
return span.is_recording()
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTelemetryAdapter(Telemetry):
|
||||||
|
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.datasetio = deps[Api.datasetio]
|
||||||
|
|
||||||
|
resource = Resource.create(
|
||||||
|
{
|
||||||
|
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = TracerProvider(resource=resource)
|
||||||
|
trace.set_tracer_provider(provider)
|
||||||
|
if TelemetrySink.JAEGER in self.config.sinks:
|
||||||
|
otlp_exporter = OTLPSpanExporter(
|
||||||
|
endpoint=self.config.otel_endpoint,
|
||||||
|
)
|
||||||
|
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||||
|
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||||
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
|
OTLPMetricExporter(
|
||||||
|
endpoint=self.config.otel_endpoint,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metric_provider = MeterProvider(
|
||||||
|
resource=resource, metric_readers=[metric_reader]
|
||||||
|
)
|
||||||
|
metrics.set_meter_provider(metric_provider)
|
||||||
|
self.meter = metrics.get_meter(__name__)
|
||||||
|
if TelemetrySink.SQLITE in self.config.sinks:
|
||||||
|
trace.get_tracer_provider().add_span_processor(
|
||||||
|
SQLiteSpanProcessor(self.config.sqlite_db_path)
|
||||||
|
)
|
||||||
|
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||||
|
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||||
|
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||||
|
self._lock = _global_lock
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
trace.get_tracer_provider().force_flush()
|
||||||
|
trace.get_tracer_provider().shutdown()
|
||||||
|
metrics.get_meter_provider().shutdown()
|
||||||
|
|
||||||
|
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||||
|
if isinstance(event, UnstructuredLogEvent):
|
||||||
|
self._log_unstructured(event, ttl_seconds)
|
||||||
|
elif isinstance(event, MetricEvent):
|
||||||
|
self._log_metric(event)
|
||||||
|
elif isinstance(event, StructuredLogEvent):
|
||||||
|
self._log_structured(event, ttl_seconds)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown event type: {event}")
|
||||||
|
|
||||||
|
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||||
|
with self._lock:
|
||||||
|
# Use global storage instead of instance storage
|
||||||
|
span_id = string_to_span_id(event.span_id)
|
||||||
|
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||||
|
|
||||||
|
if span:
|
||||||
|
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||||
|
span.add_event(
|
||||||
|
name=event.type,
|
||||||
|
attributes={
|
||||||
|
"message": event.message,
|
||||||
|
"severity": event.severity.value,
|
||||||
|
"__ttl__": ttl_seconds,
|
||||||
|
**event.attributes,
|
||||||
|
},
|
||||||
|
timestamp=timestamp_ns,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||||
|
if name not in _GLOBAL_STORAGE["counters"]:
|
||||||
|
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
description=f"Counter for {name}",
|
||||||
|
)
|
||||||
|
return _GLOBAL_STORAGE["counters"][name]
|
||||||
|
|
||||||
|
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||||
|
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||||
|
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
description=f"Gauge for {name}",
|
||||||
|
)
|
||||||
|
return _GLOBAL_STORAGE["gauges"][name]
|
||||||
|
|
||||||
|
def _log_metric(self, event: MetricEvent) -> None:
|
||||||
|
if isinstance(event.value, int):
|
||||||
|
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||||
|
counter.add(event.value, attributes=event.attributes)
|
||||||
|
elif isinstance(event.value, float):
|
||||||
|
up_down_counter = self._get_or_create_up_down_counter(
|
||||||
|
event.metric, event.unit
|
||||||
|
)
|
||||||
|
up_down_counter.add(event.value, attributes=event.attributes)
|
||||||
|
|
||||||
|
def _get_or_create_up_down_counter(
|
||||||
|
self, name: str, unit: str
|
||||||
|
) -> metrics.UpDownCounter:
|
||||||
|
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||||
|
_GLOBAL_STORAGE["up_down_counters"][name] = (
|
||||||
|
self.meter.create_up_down_counter(
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
description=f"UpDownCounter for {name}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||||
|
|
||||||
|
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||||
|
with self._lock:
|
||||||
|
span_id = string_to_span_id(event.span_id)
|
||||||
|
trace_id = string_to_trace_id(event.trace_id)
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
if event.attributes is None:
|
||||||
|
event.attributes = {}
|
||||||
|
event.attributes["__ttl__"] = ttl_seconds
|
||||||
|
|
||||||
|
if isinstance(event.payload, SpanStartPayload):
|
||||||
|
# Check if span already exists to prevent duplicates
|
||||||
|
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
parent_span = None
|
||||||
|
if event.payload.parent_span_id:
|
||||||
|
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||||
|
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||||
|
|
||||||
|
context = trace.Context(trace_id=trace_id)
|
||||||
|
if parent_span:
|
||||||
|
context = trace.set_span_in_context(parent_span, context)
|
||||||
|
|
||||||
|
span = tracer.start_span(
|
||||||
|
name=event.payload.name,
|
||||||
|
context=context,
|
||||||
|
attributes=event.attributes or {},
|
||||||
|
)
|
||||||
|
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||||
|
|
||||||
|
elif isinstance(event.payload, SpanEndPayload):
|
||||||
|
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||||
|
if span:
|
||||||
|
if event.attributes:
|
||||||
|
span.set_attributes(event.attributes)
|
||||||
|
|
||||||
|
status = (
|
||||||
|
trace.Status(status_code=trace.StatusCode.OK)
|
||||||
|
if event.payload.status == SpanStatus.OK
|
||||||
|
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||||
|
)
|
||||||
|
span.set_status(status)
|
||||||
|
span.end()
|
||||||
|
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown structured log event: {event}")
|
||||||
|
|
||||||
|
async def query_traces(
|
||||||
|
self,
|
||||||
|
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||||
|
attribute_keys_to_return: Optional[List[str]] = None,
|
||||||
|
limit: Optional[int] = 100,
|
||||||
|
offset: Optional[int] = 0,
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
) -> List[Trace]:
|
||||||
|
return await self.trace_store.query_traces(
|
||||||
|
attribute_conditions=attribute_conditions,
|
||||||
|
attribute_keys_to_return=attribute_keys_to_return,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
order_by=order_by,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_spans(
|
||||||
|
self,
|
||||||
|
span_id: str,
|
||||||
|
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||||
|
attribute_keys_to_return: Optional[List[str]] = None,
|
||||||
|
max_depth: Optional[int] = None,
|
||||||
|
limit: Optional[int] = 100,
|
||||||
|
offset: Optional[int] = 0,
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
) -> SpanWithChildren:
|
||||||
|
return await self.trace_store.get_spans(
|
||||||
|
span_id=span_id,
|
||||||
|
attribute_conditions=attribute_conditions,
|
||||||
|
attribute_keys_to_return=attribute_keys_to_return,
|
||||||
|
max_depth=max_depth,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
order_by=order_by,
|
||||||
|
)
|
|
@ -11,8 +11,8 @@ from typing import List, Optional
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import (
|
from llama_stack.apis.telemetry import (
|
||||||
MaterializedSpan,
|
|
||||||
QueryCondition,
|
QueryCondition,
|
||||||
|
SpanWithChildren,
|
||||||
Trace,
|
Trace,
|
||||||
TraceStore,
|
TraceStore,
|
||||||
)
|
)
|
||||||
|
@ -24,56 +24,76 @@ class SQLiteTraceStore(TraceStore):
|
||||||
|
|
||||||
async def query_traces(
|
async def query_traces(
|
||||||
self,
|
self,
|
||||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
attributes_to_return: Optional[List[str]] = None,
|
||||||
limit: Optional[int] = 100,
|
limit: Optional[int] = 100,
|
||||||
offset: Optional[int] = 0,
|
offset: Optional[int] = 0,
|
||||||
order_by: Optional[List[str]] = None,
|
order_by: Optional[List[str]] = None,
|
||||||
) -> List[Trace]:
|
) -> List[Trace]:
|
||||||
# Build the SQL query with attribute selection
|
print(attribute_filters, attributes_to_return, limit, offset, order_by)
|
||||||
select_clause = """
|
|
||||||
SELECT DISTINCT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
|
||||||
"""
|
|
||||||
if attribute_keys_to_return:
|
|
||||||
for key in attribute_keys_to_return:
|
|
||||||
select_clause += (
|
|
||||||
f", json_extract(s.attributes, '$.{key}') as attr_{key}"
|
|
||||||
)
|
|
||||||
|
|
||||||
query = (
|
def build_attribute_select() -> str:
|
||||||
select_clause
|
if not attributes_to_return:
|
||||||
+ """
|
return ""
|
||||||
FROM traces t
|
return "".join(
|
||||||
JOIN spans s ON t.trace_id = s.trace_id
|
f", json_extract(s.attributes, '$.{key}') as attr_{key}"
|
||||||
"""
|
for key in attributes_to_return
|
||||||
)
|
)
|
||||||
params = []
|
|
||||||
|
|
||||||
# Add attribute conditions if present
|
def build_where_clause() -> tuple[str, list]:
|
||||||
if attribute_conditions:
|
if not attribute_filters:
|
||||||
conditions = []
|
return "", []
|
||||||
for condition in attribute_conditions:
|
|
||||||
conditions.append(
|
conditions = [
|
||||||
f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?"
|
f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?"
|
||||||
)
|
for condition in attribute_filters
|
||||||
params.append(condition.value)
|
]
|
||||||
if conditions:
|
params = [condition.value for condition in attribute_filters]
|
||||||
query += " WHERE " + " AND ".join(conditions)
|
where_clause = " WHERE " + " AND ".join(conditions)
|
||||||
|
return where_clause, params
|
||||||
|
|
||||||
|
def build_order_clause() -> str:
|
||||||
|
if not order_by:
|
||||||
|
return ""
|
||||||
|
|
||||||
# Add ordering
|
|
||||||
if order_by:
|
|
||||||
order_clauses = []
|
order_clauses = []
|
||||||
for field in order_by:
|
for field in order_by:
|
||||||
desc = False
|
desc = field.startswith("-")
|
||||||
if field.startswith("-"):
|
clean_field = field[1:] if desc else field
|
||||||
field = field[1:]
|
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
|
||||||
desc = True
|
return " ORDER BY " + ", ".join(order_clauses)
|
||||||
order_clauses.append(f"t.{field} {'DESC' if desc else 'ASC'}")
|
|
||||||
query += " ORDER BY " + ", ".join(order_clauses)
|
|
||||||
|
|
||||||
# Add limit and offset
|
# Build the main query
|
||||||
query += f" LIMIT {limit} OFFSET {offset}"
|
base_query = """
|
||||||
|
WITH matching_traces AS (
|
||||||
|
SELECT DISTINCT t.trace_id
|
||||||
|
FROM traces t
|
||||||
|
JOIN spans s ON t.trace_id = s.trace_id
|
||||||
|
{where_clause}
|
||||||
|
),
|
||||||
|
filtered_traces AS (
|
||||||
|
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||||
|
{attribute_select}
|
||||||
|
FROM matching_traces mt
|
||||||
|
JOIN traces t ON mt.trace_id = t.trace_id
|
||||||
|
LEFT JOIN spans s ON t.trace_id = s.trace_id
|
||||||
|
{order_clause}
|
||||||
|
)
|
||||||
|
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||||
|
FROM filtered_traces
|
||||||
|
LIMIT {limit} OFFSET {offset}
|
||||||
|
"""
|
||||||
|
|
||||||
|
where_clause, params = build_where_clause()
|
||||||
|
query = base_query.format(
|
||||||
|
attribute_select=build_attribute_select(),
|
||||||
|
where_clause=where_clause,
|
||||||
|
order_clause=build_order_clause(),
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute query and return results
|
||||||
async with aiosqlite.connect(self.conn_string) as conn:
|
async with aiosqlite.connect(self.conn_string) as conn:
|
||||||
conn.row_factory = aiosqlite.Row
|
conn.row_factory = aiosqlite.Row
|
||||||
async with conn.execute(query, params) as cursor:
|
async with conn.execute(query, params) as cursor:
|
||||||
|
@ -91,15 +111,15 @@ class SQLiteTraceStore(TraceStore):
|
||||||
async def get_materialized_span(
|
async def get_materialized_span(
|
||||||
self,
|
self,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
attribute_keys_to_return: Optional[List[str]] = None,
|
attributes_to_return: Optional[List[str]] = None,
|
||||||
max_depth: Optional[int] = None,
|
max_depth: Optional[int] = None,
|
||||||
) -> MaterializedSpan:
|
) -> SpanWithChildren:
|
||||||
# Build the attributes selection
|
# Build the attributes selection
|
||||||
attributes_select = "s.attributes"
|
attributes_select = "s.attributes"
|
||||||
if attribute_keys_to_return:
|
if attributes_to_return:
|
||||||
json_object = ", ".join(
|
json_object = ", ".join(
|
||||||
f"'{key}', json_extract(s.attributes, '$.{key}')"
|
f"'{key}', json_extract(s.attributes, '$.{key}')"
|
||||||
for key in attribute_keys_to_return
|
for key in attributes_to_return
|
||||||
)
|
)
|
||||||
attributes_select = f"json_object({json_object})"
|
attributes_select = f"json_object({json_object})"
|
||||||
|
|
||||||
|
@ -135,7 +155,7 @@ class SQLiteTraceStore(TraceStore):
|
||||||
root_span = None
|
root_span = None
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
span = MaterializedSpan(
|
span = SpanWithChildren(
|
||||||
span_id=row["span_id"],
|
span_id=row["span_id"],
|
||||||
trace_id=row["trace_id"],
|
trace_id=row["trace_id"],
|
||||||
parent_span_id=row["parent_span_id"],
|
parent_span_id=row["parent_span_id"],
|
||||||
|
|
180
llama_stack/providers/utils/telemetry/sqlite_trace_store.py
Normal file
180
llama_stack/providers/utils/telemetry/sqlite_trace_store.py
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
|
||||||
|
|
||||||
|
|
||||||
|
class TraceStore(Protocol):
|
||||||
|
|
||||||
|
async def query_traces(
|
||||||
|
self,
|
||||||
|
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||||
|
limit: Optional[int] = 100,
|
||||||
|
offset: Optional[int] = 0,
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
) -> List[Trace]: ...
|
||||||
|
|
||||||
|
async def get_materialized_span(
|
||||||
|
self,
|
||||||
|
span_id: str,
|
||||||
|
attributes_to_return: Optional[List[str]] = None,
|
||||||
|
max_depth: Optional[int] = None,
|
||||||
|
) -> SpanWithChildren: ...
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteTraceStore(TraceStore):
|
||||||
|
def __init__(self, conn_string: str):
|
||||||
|
self.conn_string = conn_string
|
||||||
|
|
||||||
|
async def query_traces(
|
||||||
|
self,
|
||||||
|
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||||
|
limit: Optional[int] = 100,
|
||||||
|
offset: Optional[int] = 0,
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
) -> List[Trace]:
|
||||||
|
|
||||||
|
def build_where_clause() -> tuple[str, list]:
|
||||||
|
if not attribute_filters:
|
||||||
|
return "", []
|
||||||
|
|
||||||
|
ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
|
||||||
|
|
||||||
|
conditions = [
|
||||||
|
f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?"
|
||||||
|
for condition in attribute_filters
|
||||||
|
]
|
||||||
|
params = [condition.value for condition in attribute_filters]
|
||||||
|
where_clause = " WHERE " + " AND ".join(conditions)
|
||||||
|
return where_clause, params
|
||||||
|
|
||||||
|
def build_order_clause() -> str:
|
||||||
|
if not order_by:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
order_clauses = []
|
||||||
|
for field in order_by:
|
||||||
|
desc = field.startswith("-")
|
||||||
|
clean_field = field[1:] if desc else field
|
||||||
|
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
|
||||||
|
return " ORDER BY " + ", ".join(order_clauses)
|
||||||
|
|
||||||
|
# Build the main query
|
||||||
|
base_query = """
|
||||||
|
WITH matching_traces AS (
|
||||||
|
SELECT DISTINCT t.trace_id
|
||||||
|
FROM traces t
|
||||||
|
JOIN spans s ON t.trace_id = s.trace_id
|
||||||
|
{where_clause}
|
||||||
|
),
|
||||||
|
filtered_traces AS (
|
||||||
|
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||||
|
FROM matching_traces mt
|
||||||
|
JOIN traces t ON mt.trace_id = t.trace_id
|
||||||
|
LEFT JOIN spans s ON t.trace_id = s.trace_id
|
||||||
|
{order_clause}
|
||||||
|
)
|
||||||
|
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||||
|
FROM filtered_traces
|
||||||
|
LIMIT {limit} OFFSET {offset}
|
||||||
|
"""
|
||||||
|
|
||||||
|
where_clause, params = build_where_clause()
|
||||||
|
query = base_query.format(
|
||||||
|
where_clause=where_clause,
|
||||||
|
order_clause=build_order_clause(),
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute query and return results
|
||||||
|
async with aiosqlite.connect(self.conn_string) as conn:
|
||||||
|
conn.row_factory = aiosqlite.Row
|
||||||
|
async with conn.execute(query, params) as cursor:
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [
|
||||||
|
Trace(
|
||||||
|
trace_id=row["trace_id"],
|
||||||
|
root_span_id=row["root_span_id"],
|
||||||
|
start_time=datetime.fromisoformat(row["start_time"]),
|
||||||
|
end_time=datetime.fromisoformat(row["end_time"]),
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_materialized_span(
|
||||||
|
self,
|
||||||
|
span_id: str,
|
||||||
|
attributes_to_return: Optional[List[str]] = None,
|
||||||
|
max_depth: Optional[int] = None,
|
||||||
|
) -> SpanWithChildren:
|
||||||
|
# Build the attributes selection
|
||||||
|
attributes_select = "s.attributes"
|
||||||
|
if attributes_to_return:
|
||||||
|
json_object = ", ".join(
|
||||||
|
f"'{key}', json_extract(s.attributes, '$.{key}')"
|
||||||
|
for key in attributes_to_return
|
||||||
|
)
|
||||||
|
attributes_select = f"json_object({json_object})"
|
||||||
|
|
||||||
|
# SQLite CTE query with filtered attributes
|
||||||
|
query = f"""
|
||||||
|
WITH RECURSIVE span_tree AS (
|
||||||
|
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
|
||||||
|
FROM spans s
|
||||||
|
WHERE s.span_id = ?
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
|
||||||
|
FROM spans s
|
||||||
|
JOIN span_tree st ON s.parent_span_id = st.span_id
|
||||||
|
WHERE (? IS NULL OR st.depth < ?)
|
||||||
|
)
|
||||||
|
SELECT *
|
||||||
|
FROM span_tree
|
||||||
|
ORDER BY depth, start_time
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with aiosqlite.connect(self.conn_string) as conn:
|
||||||
|
conn.row_factory = aiosqlite.Row
|
||||||
|
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
raise ValueError(f"Span {span_id} not found")
|
||||||
|
|
||||||
|
# Build span tree
|
||||||
|
spans_by_id = {}
|
||||||
|
root_span = None
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
span = SpanWithChildren(
|
||||||
|
span_id=row["span_id"],
|
||||||
|
trace_id=row["trace_id"],
|
||||||
|
parent_span_id=row["parent_span_id"],
|
||||||
|
name=row["name"],
|
||||||
|
start_time=datetime.fromisoformat(row["start_time"]),
|
||||||
|
end_time=datetime.fromisoformat(row["end_time"]),
|
||||||
|
attributes=json.loads(row["filtered_attributes"]),
|
||||||
|
status=row["status"].lower(),
|
||||||
|
children=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
spans_by_id[span.span_id] = span
|
||||||
|
|
||||||
|
if span.span_id == span_id:
|
||||||
|
root_span = span
|
||||||
|
elif span.parent_span_id in spans_by_id:
|
||||||
|
spans_by_id[span.parent_span_id].children.append(span)
|
||||||
|
|
||||||
|
return root_span
|
Loading…
Add table
Add a link
Reference in a new issue