mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 09:02:37 +00:00
253 lines
6.3 KiB
Python
253 lines
6.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Protocol,
|
|
runtime_checkable,
|
|
Union,
|
|
)
|
|
|
|
from llama_models.schema_utils import json_schema_type, webmethod
|
|
from pydantic import BaseModel, Field
|
|
from typing_extensions import Annotated
|
|
|
|
from llama_stack.apis.datasetio import DatasetIO
|
|
|
|
# Add this constant near the top of the file, after the imports
|
|
DEFAULT_TTL_DAYS = 7
|
|
|
|
|
|
@json_schema_type
|
|
class SpanStatus(Enum):
|
|
OK = "ok"
|
|
ERROR = "error"
|
|
|
|
|
|
@json_schema_type
|
|
class Span(BaseModel):
|
|
span_id: str
|
|
trace_id: str
|
|
parent_span_id: Optional[str] = None
|
|
name: str
|
|
start_time: datetime
|
|
end_time: Optional[datetime] = None
|
|
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
|
|
|
def set_attribute(self, key: str, value: Any):
|
|
if self.attributes is None:
|
|
self.attributes = {}
|
|
self.attributes[key] = value
|
|
|
|
|
|
@json_schema_type
|
|
class Trace(BaseModel):
|
|
trace_id: str
|
|
root_span_id: str
|
|
start_time: datetime
|
|
end_time: Optional[datetime] = None
|
|
|
|
|
|
@json_schema_type
|
|
class EventType(Enum):
|
|
UNSTRUCTURED_LOG = "unstructured_log"
|
|
STRUCTURED_LOG = "structured_log"
|
|
METRIC = "metric"
|
|
|
|
|
|
@json_schema_type
|
|
class LogSeverity(Enum):
|
|
VERBOSE = "verbose"
|
|
DEBUG = "debug"
|
|
INFO = "info"
|
|
WARN = "warn"
|
|
ERROR = "error"
|
|
CRITICAL = "critical"
|
|
|
|
|
|
class EventCommon(BaseModel):
|
|
trace_id: str
|
|
span_id: str
|
|
timestamp: datetime
|
|
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
|
|
|
|
|
@json_schema_type
|
|
class UnstructuredLogEvent(EventCommon):
|
|
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
|
|
message: str
|
|
severity: LogSeverity
|
|
|
|
|
|
@json_schema_type
|
|
class MetricEvent(EventCommon):
|
|
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
|
|
metric: str # this would be an enum
|
|
value: Union[int, float]
|
|
unit: str
|
|
|
|
|
|
@json_schema_type
|
|
class StructuredLogType(Enum):
|
|
SPAN_START = "span_start"
|
|
SPAN_END = "span_end"
|
|
|
|
|
|
@json_schema_type
|
|
class SpanStartPayload(BaseModel):
|
|
type: Literal[StructuredLogType.SPAN_START.value] = (
|
|
StructuredLogType.SPAN_START.value
|
|
)
|
|
name: str
|
|
parent_span_id: Optional[str] = None
|
|
|
|
|
|
@json_schema_type
|
|
class SpanEndPayload(BaseModel):
|
|
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
|
|
status: SpanStatus
|
|
|
|
|
|
StructuredLogPayload = Annotated[
|
|
Union[
|
|
SpanStartPayload,
|
|
SpanEndPayload,
|
|
],
|
|
Field(discriminator="type"),
|
|
]
|
|
|
|
|
|
@json_schema_type
|
|
class StructuredLogEvent(EventCommon):
|
|
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
|
|
payload: StructuredLogPayload
|
|
|
|
|
|
Event = Annotated[
|
|
Union[
|
|
UnstructuredLogEvent,
|
|
MetricEvent,
|
|
StructuredLogEvent,
|
|
],
|
|
Field(discriminator="type"),
|
|
]
|
|
|
|
|
|
@json_schema_type
|
|
class EvalTrace(BaseModel):
|
|
session_id: str
|
|
step: str
|
|
input: str
|
|
output: str
|
|
expected_output: str
|
|
|
|
|
|
@json_schema_type
|
|
class SpanWithChildren(Span):
|
|
children: List["SpanWithChildren"] = Field(default_factory=list)
|
|
status: Optional[SpanStatus] = None
|
|
|
|
|
|
@json_schema_type
|
|
class QueryCondition(BaseModel):
|
|
key: str
|
|
op: Literal["eq", "ne", "gt", "lt"]
|
|
value: Any
|
|
|
|
|
|
@runtime_checkable
|
|
class Telemetry(Protocol):
|
|
|
|
datasetio_api: DatasetIO
|
|
|
|
@webmethod(route="/telemetry/log-event")
|
|
async def log_event(
|
|
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
|
|
) -> None: ...
|
|
|
|
@webmethod(route="/telemetry/query-traces", method="POST")
|
|
async def query_traces(
|
|
self,
|
|
attribute_filters: Optional[List[QueryCondition]] = None,
|
|
limit: Optional[int] = 100,
|
|
offset: Optional[int] = 0,
|
|
order_by: Optional[List[str]] = None,
|
|
) -> List[Trace]: ...
|
|
|
|
@webmethod(route="/telemetry/get-span-tree", method="POST")
|
|
async def get_span_tree(
|
|
self,
|
|
span_id: str,
|
|
attributes_to_return: Optional[List[str]] = None,
|
|
max_depth: Optional[int] = None,
|
|
) -> SpanWithChildren: ...
|
|
|
|
@webmethod(route="/telemetry/query-spans", method="POST")
|
|
async def query_spans(
|
|
self,
|
|
attribute_filters: List[QueryCondition],
|
|
attributes_to_return: List[str],
|
|
max_depth: Optional[int] = None,
|
|
) -> List[Dict[str, Any]]:
|
|
traces = await self.query_traces(attribute_filters=attribute_filters)
|
|
|
|
rows = []
|
|
|
|
for trace in traces:
|
|
span_tree = await self.get_span_tree(
|
|
span_id=trace.root_span_id,
|
|
attributes_to_return=attributes_to_return,
|
|
max_depth=max_depth,
|
|
)
|
|
|
|
def extract_spans(span: SpanWithChildren) -> List[Dict[str, Any]]:
|
|
rows = []
|
|
if span.attributes and all(
|
|
attr in span.attributes and span.attributes[attr] is not None
|
|
for attr in attributes_to_return
|
|
):
|
|
row = {
|
|
"trace_id": trace.root_span_id,
|
|
"span_id": span.span_id,
|
|
"step_name": span.name,
|
|
}
|
|
for attr in attributes_to_return:
|
|
row[attr] = str(span.attributes[attr])
|
|
rows.append(row)
|
|
|
|
for child in span.children:
|
|
rows.extend(extract_spans(child))
|
|
|
|
return rows
|
|
|
|
rows.extend(extract_spans(span_tree))
|
|
|
|
return rows
|
|
|
|
@webmethod(route="/telemetry/save-traces-to-dataset", method="POST")
|
|
async def save_traces_to_dataset(
|
|
self,
|
|
attribute_filters: List[QueryCondition],
|
|
attributes_to_save: List[str],
|
|
dataset_id: str,
|
|
max_depth: Optional[int] = None,
|
|
) -> None:
|
|
annotation_rows = await self.query_spans(
|
|
attribute_filters=attribute_filters,
|
|
attributes_to_return=attributes_to_save,
|
|
max_depth=max_depth,
|
|
)
|
|
|
|
if annotation_rows:
|
|
await self.datasetio_api.append_rows(
|
|
dataset_id=dataset_id, rows=annotation_rows
|
|
)
|