From fcd64495195a53d78ebd7ec45b93e3b3d1143a57 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 4 Dec 2024 11:22:45 -0800 Subject: [PATCH] Telemetry API redesign (#525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Change the Telemetry API to be able to support different use cases like returning traces for the UI and ability to export for Evals. Other changes: * Add a new trace_protocol decorator to decorate all our API methods so that any call to them will automatically get traced across all impls. * There is some issue with the decorator pattern of span creation when using async generators, where there are multiple yields with in the same context. I think its much more explicit by using the explicit context manager pattern using with. I moved the span creations in agent instance to be using with * Inject session id at the turn level, which should quickly give us all traces across turns for a given session Addresses #509 ## Test Plan ``` llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000 curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \ -H 'Content-Type: application/json' \ -d '{ "attribute_filters": [ { "key": "session_id", "op": "eq", "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65" }], "limit": 100, "offset": 0, "order_by": ["start_time"] }' | jq . [ { "trace_id": "6902f54b83b4b48be18a6f422b13e16f", "root_span_id": "5f37b85543afc15a", "start_time": "2024-12-04T08:08:30.501587", "end_time": "2024-12-04T08:08:36.026463" }, { "trace_id": "92227dac84c0615ed741be393813fb5f", "root_span_id": "af7c5bb46665c2c8", "start_time": "2024-12-04T08:08:36.031170", "end_time": "2024-12-04T08:08:41.693301" }, { "trace_id": "7d578a6edac62f204ab479fba82f77b6", "root_span_id": "1d935e3362676896", "start_time": "2024-12-04T08:08:41.695204", "end_time": "2024-12-04T08:08:47.228016" }, { "trace_id": "dbd767d76991bc816f9f078907dc9ff2", "root_span_id": "f5a7ee76683b9602", "start_time": "2024-12-04T08:08:47.234578", "end_time": "2024-12-04T08:08:53.189412" } ] curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \ -H 'Content-Type: application/json' \ -d '{ "span_id" : "6cceb4b48a156913", "max_depth": 2, "attributes_to_return": ["input"] }' | jq . % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 875 100 790 100 85 18462 1986 --:--:-- --:--:-- --:--:-- 20833 { "span_id": "6cceb4b48a156913", "trace_id": "dafa796f6aaf925f511c04cd7c67fdda", "parent_span_id": "892a66d726c7f990", "name": "retrieve_rag_context", "start_time": "2024-12-04T09:28:21.781995", "end_time": "2024-12-04T09:28:21.913352", "attributes": { "input": [ "{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}", "{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}" ] }, "children": [ { "span_id": "1a2df181854064a8", "trace_id": "dafa796f6aaf925f511c04cd7c67fdda", "parent_span_id": "6cceb4b48a156913", "name": "MemoryRouter.query_documents", "start_time": "2024-12-04T09:28:21.787620", "end_time": "2024-12-04T09:28:21.906512", "attributes": { "input": null }, "children": [], "status": "ok" } ], "status": "ok" } ``` Screenshot 2024-12-04 at 9 42 56 AM --- llama_stack/apis/agents/agents.py | 2 + llama_stack/apis/datasetio/datasetio.py | 5 + llama_stack/apis/inference/inference.py | 3 + llama_stack/apis/memory/memory.py | 2 + llama_stack/apis/memory_banks/memory_banks.py | 2 + llama_stack/apis/models/models.py | 2 + llama_stack/apis/safety/safety.py | 3 + llama_stack/apis/shields/shields.py | 2 + llama_stack/apis/telemetry/telemetry.py | 66 ++++- llama_stack/distribution/routers/routers.py | 6 + llama_stack/distribution/server/server.py | 8 +- llama_stack/distribution/tracing.py | 128 +++++++++ .../agents/meta_reference/agent_instance.py | 227 +++++++++------- .../inline/datasetio/localfs/datasetio.py | 43 ++- .../meta_reference/telemetry/__init__.py | 15 -- .../inline/meta_reference/telemetry/config.py | 21 -- .../meta_reference/telemetry/console.py | 25 +- .../{remote => inline}/telemetry/__init__.py | 0 .../telemetry/meta_reference/__init__.py | 18 ++ .../inline/telemetry/meta_reference/config.py | 45 ++++ .../meta_reference/console_span_processor.py | 95 +++++++ .../meta_reference/sqlite_span_processor.py | 242 +++++++++++++++++ .../telemetry/meta_reference/telemetry.py | 247 ++++++++++++++++++ .../telemetry/sample/__init__.py | 0 .../telemetry/sample/config.py | 0 .../telemetry/sample/sample.py | 0 llama_stack/providers/registry/telemetry.py | 23 +- .../datasetio/huggingface/huggingface.py | 21 +- .../telemetry/opentelemetry/__init__.py | 15 -- .../remote/telemetry/opentelemetry/config.py | 27 -- .../telemetry/opentelemetry/opentelemetry.py | 115 +++++--- .../providers/utils/telemetry/sqlite.py | 177 +++++++++++++ .../utils/telemetry/sqlite_trace_store.py | 180 +++++++++++++ .../providers/utils/telemetry/tracing.py | 31 ++- 34 files changed, 1551 insertions(+), 245 deletions(-) create mode 100644 llama_stack/distribution/tracing.py delete mode 100644 llama_stack/providers/inline/meta_reference/telemetry/__init__.py delete mode 100644 llama_stack/providers/inline/meta_reference/telemetry/config.py rename llama_stack/providers/{remote => inline}/telemetry/__init__.py (100%) create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/__init__.py create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/config.py create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py create mode 100644 llama_stack/providers/inline/telemetry/meta_reference/telemetry.py rename llama_stack/providers/{remote => inline}/telemetry/sample/__init__.py (100%) rename llama_stack/providers/{remote => inline}/telemetry/sample/config.py (100%) rename llama_stack/providers/{remote => inline}/telemetry/sample/sample.py (100%) delete mode 100644 llama_stack/providers/remote/telemetry/opentelemetry/__init__.py delete mode 100644 llama_stack/providers/remote/telemetry/opentelemetry/config.py create mode 100644 llama_stack/providers/utils/telemetry/sqlite.py create mode 100644 llama_stack/providers/utils/telemetry/sqlite_trace_store.py diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 25de35497..d2243c96f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -23,6 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated +from llama_stack.distribution.tracing import trace_protocol from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.common.deployment_types import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 @@ -418,6 +419,7 @@ class AgentStepResponse(BaseModel): @runtime_checkable +@trace_protocol class Agents(Protocol): @webmethod(route="/agents/create") async def create_agent( diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index c5052877a..22acc3211 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -37,3 +37,8 @@ class DatasetIO(Protocol): page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: ... + + @webmethod(route="/datasetio/append-rows", method="POST") + async def append_rows( + self, dataset_id: str, rows: List[Dict[str, Any]] + ) -> None: ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 5aadd97c7..85b29a147 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.distribution.tracing import trace_protocol + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 @@ -220,6 +222,7 @@ class ModelStore(Protocol): @runtime_checkable +@trace_protocol class Inference(Protocol): model_store: ModelStore diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 48b6e2241..b75df8a1a 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.distribution.tracing import trace_protocol @json_schema_type @@ -43,6 +44,7 @@ class MemoryBankStore(Protocol): @runtime_checkable +@trace_protocol class Memory(Protocol): memory_bank_store: MemoryBankStore diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 1b16af330..0b8b2563f 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -20,6 +20,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol @json_schema_type @@ -129,6 +130,7 @@ class MemoryBankInput(BaseModel): @runtime_checkable +@trace_protocol class MemoryBanks(Protocol): @webmethod(route="/memory-banks/list", method="GET") async def list_memory_banks(self) -> List[MemoryBank]: ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index cbd6265e2..2c0f1ee21 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol class CommonModelFields(BaseModel): @@ -43,6 +44,7 @@ class ModelInput(CommonModelFields): @runtime_checkable +@trace_protocol class Models(Protocol): @webmethod(route="/models/list", method="GET") async def list_models(self) -> List[Model]: ... diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 724f8dc96..41058f107 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -10,6 +10,8 @@ from typing import Any, Dict, List, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel +from llama_stack.distribution.tracing import trace_protocol + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 @@ -43,6 +45,7 @@ class ShieldStore(Protocol): @runtime_checkable +@trace_protocol class Safety(Protocol): shield_store: ShieldStore diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 5ee444f68..b28605727 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol class CommonShieldFields(BaseModel): @@ -38,6 +39,7 @@ class ShieldInput(CommonShieldFields): @runtime_checkable +@trace_protocol class Shields(Protocol): @webmethod(route="/shields/list", method="GET") async def list_shields(self) -> List[Shield]: ... diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 31f64733b..2ff783c46 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -6,12 +6,24 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +# Add this constant near the top of the file, after the imports +DEFAULT_TTL_DAYS = 7 + @json_schema_type class SpanStatus(Enum): @@ -29,6 +41,11 @@ class Span(BaseModel): end_time: Optional[datetime] = None attributes: Optional[Dict[str, Any]] = Field(default_factory=dict) + def set_attribute(self, key: str, value: Any): + if self.attributes is None: + self.attributes = {} + self.attributes[key] = value + @json_schema_type class Trace(BaseModel): @@ -123,10 +140,49 @@ Event = Annotated[ ] +@json_schema_type +class EvalTrace(BaseModel): + session_id: str + step: str + input: str + output: str + expected_output: str + + +@json_schema_type +class SpanWithChildren(Span): + children: List["SpanWithChildren"] = Field(default_factory=list) + status: Optional[SpanStatus] = None + + +@json_schema_type +class QueryCondition(BaseModel): + key: str + op: Literal["eq", "ne", "gt", "lt"] + value: Any + + @runtime_checkable class Telemetry(Protocol): - @webmethod(route="/telemetry/log-event") - async def log_event(self, event: Event) -> None: ... - @webmethod(route="/telemetry/get-trace", method="GET") - async def get_trace(self, trace_id: str) -> Trace: ... + @webmethod(route="/telemetry/log-event") + async def log_event( + self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400 + ) -> None: ... + + @webmethod(route="/telemetry/query-traces", method="POST") + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... + + @webmethod(route="/telemetry/get-span-tree", method="POST") + async def get_span_tree( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5a62b6d64..5b75a525b 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO): filter_condition=filter_condition, ) + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + return await self.routing_table.get_provider_impl(dataset_id).append_rows( + dataset_id=dataset_id, + rows=rows, + ) + class ScoringRouter(Scoring): def __init__( diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8116e2b39..4ae1854df 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -43,9 +43,9 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) -from llama_stack.providers.inline.meta_reference.telemetry.console import ( - ConsoleConfig, - ConsoleTelemetryImpl, +from llama_stack.providers.inline.telemetry.meta_reference import ( + TelemetryAdapter, + TelemetryConfig, ) from .endpoints import get_all_api_endpoints @@ -290,7 +290,7 @@ def main(): if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: - setup_logger(ConsoleTelemetryImpl(ConsoleConfig())) + setup_logger(TelemetryAdapter(TelemetryConfig())) all_endpoints = get_all_api_endpoints() diff --git a/llama_stack/distribution/tracing.py b/llama_stack/distribution/tracing.py new file mode 100644 index 000000000..ea663ec89 --- /dev/null +++ b/llama_stack/distribution/tracing.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import inspect +import json +from functools import wraps +from typing import Any, AsyncGenerator, Callable, Type, TypeVar + +from pydantic import BaseModel + +from llama_stack.providers.utils.telemetry import tracing + +T = TypeVar("T") + + +def serialize_value(value: Any) -> str: + """Helper function to serialize values to string representation.""" + try: + if isinstance(value, BaseModel): + return value.model_dump_json() + elif isinstance(value, list) and value and isinstance(value[0], BaseModel): + return json.dumps([item.model_dump_json() for item in value]) + elif hasattr(value, "to_dict"): + return json.dumps(value.to_dict()) + elif isinstance(value, (dict, list, int, float, str, bool)): + return json.dumps(value) + else: + return str(value) + except Exception: + return str(value) + + +def trace_protocol(cls: Type[T]) -> Type[T]: + """ + A class decorator that automatically traces all methods in a protocol/base class + and its inheriting classes. + """ + + def trace_method(method: Callable) -> Callable: + is_async = asyncio.iscoroutinefunction(method) + is_async_gen = inspect.isasyncgenfunction(method) + + def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: + class_name = self.__class__.__name__ + method_name = method.__name__ + + span_type = ( + "async_generator" if is_async_gen else "async" if is_async else "sync" + ) + span_attributes = { + "class": class_name, + "method": method_name, + "type": span_type, + "args": serialize_value(args), + } + + return class_name, method_name, span_attributes + + @wraps(method) + async def async_gen_wrapper( + self: Any, *args: Any, **kwargs: Any + ) -> AsyncGenerator: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + count = 0 + async for item in method(self, *args, **kwargs): + yield item + count += 1 + finally: + span.set_attribute("chunk_count", count) + + @wraps(method) + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + result = await method(self, *args, **kwargs) + span.set_attribute("output", serialize_value(result)) + return result + except Exception as e: + span.set_attribute("error", str(e)) + raise + + @wraps(method) + def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + result = method(self, *args, **kwargs) + span.set_attribute("output", serialize_value(result)) + return result + except Exception as e: + raise + + if is_async_gen: + return async_gen_wrapper + elif is_async: + return async_wrapper + else: + return sync_wrapper + + original_init_subclass = getattr(cls, "__init_subclass__", None) + + def __init_subclass__(cls_child, **kwargs): # noqa: N807 + if original_init_subclass: + original_init_subclass(**kwargs) + + for name, method in vars(cls_child).items(): + if inspect.isfunction(method) and not name.startswith("_"): + setattr(cls_child, name, trace_method(method)) # noqa: B010 + + cls.__init_subclass__ = classmethod(__init_subclass__) + + return cls diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8f800ad6f..7df5d3bd4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -144,87 +144,91 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) - @tracing.span("create_and_execute_turn") async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: - assert request.stream is True, "Non-streaming not supported" + with tracing.span("create_and_execute_turn") as span: + span.set_attribute("session_id", request.session_id) + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" - session_info = await self.storage.get_session_info(request.session_id) - if session_info is None: - raise ValueError(f"Session {request.session_id} not found") + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") - turns = await self.storage.get_session_turns(request.session_id) + turns = await self.storage.get_session_turns(request.session_id) - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) + for i, turn in enumerate(turns): + messages.extend(self.turn_to_messages(turn)) - messages.extend(request.messages) + messages.extend(request.messages) - turn_id = str(uuid.uuid4()) - start_time = datetime.now() - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnStartPayload( - turn_id=turn_id, + turn_id = str(uuid.uuid4()) + span.set_attribute("turn_id", turn_id) + start_time = datetime.now() + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnStartPayload( + turn_id=turn_id, + ) ) ) - ) - steps = [] - output_message = None - async for chunk in self.run( - session_id=request.session_id, - turn_id=turn_id, - input_messages=messages, - attachments=request.attachments or [], - sampling_params=self.agent_config.sampling_params, - stream=request.stream, - ): - if isinstance(chunk, CompletionMessage): - log.info( - f"{chunk.role.capitalize()}: {chunk.content}", - ) - output_message = chunk - continue - - assert isinstance( - chunk, AgentTurnResponseStreamChunk - ), f"Unexpected type {type(chunk)}" - event = chunk.event - if ( - event.payload.event_type - == AgentTurnResponseEventType.step_complete.value + steps = [] + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=turn_id, + input_messages=messages, + attachments=request.attachments or [], + sampling_params=self.agent_config.sampling_params, + stream=request.stream, ): - steps.append(event.payload.step_details) + if isinstance(chunk, CompletionMessage): + log.info( + f"{chunk.role.capitalize()}: {chunk.content}", + ) + output_message = chunk + continue - yield chunk + assert isinstance( + chunk, AgentTurnResponseStreamChunk + ), f"Unexpected type {type(chunk)}" + event = chunk.event + if ( + event.payload.event_type + == AgentTurnResponseEventType.step_complete.value + ): + steps.append(event.payload.step_details) - assert output_message is not None + yield chunk - turn = Turn( - turn_id=turn_id, - session_id=request.session_id, - input_messages=request.messages, - output_message=output_message, - started_at=start_time, - completed_at=datetime.now(), - steps=steps, - ) - await self.storage.add_turn_to_session(request.session_id, turn) + assert output_message is not None - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + turn = Turn( + turn_id=turn_id, + session_id=request.session_id, + input_messages=request.messages, + output_message=output_message, + started_at=start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) ) ) - ) - yield chunk + yield chunk async def run( self, @@ -273,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin): yield final_response - @tracing.span("run_shields") async def run_multiple_shields_wrapper( self, turn_id: str, @@ -281,23 +284,47 @@ class ChatAgent(ShieldRunnerMixin): shields: List[str], touchpoint: str, ) -> AsyncGenerator: - if len(shields) == 0: - return + with tracing.span("run_shields") as span: + span.set_attribute("turn_id", turn_id) + span.set_attribute("input", [m.model_dump_json() for m in messages]) + if len(shields) == 0: + span.set_attribute("output", "no shields") + return - step_id = str(uuid.uuid4()) - try: - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.shield_call.value, - step_id=step_id, - metadata=dict(touchpoint=touchpoint), + step_id = str(uuid.uuid4()) + try: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( + step_type=StepType.shield_call.value, + step_id=step_id, + metadata=dict(touchpoint=touchpoint), + ) ) ) - ) - await self.run_multiple_shields(messages, shields) + await self.run_multiple_shields(messages, shields) + + except SafetyException as e: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.shield_call.value, + step_details=ShieldCallStep( + step_id=step_id, + turn_id=turn_id, + violation=e.violation, + ), + ) + ) + ) + span.set_attribute("output", e.violation.model_dump_json()) + + yield CompletionMessage( + content=str(e), + stop_reason=StopReason.end_of_turn, + ) + yield False - except SafetyException as e: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( @@ -305,30 +332,12 @@ class ChatAgent(ShieldRunnerMixin): step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, - violation=e.violation, + violation=None, ), ) ) ) - - yield CompletionMessage( - content=str(e), - stop_reason=StopReason.end_of_turn, - ) - yield False - - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.shield_call.value, - step_details=ShieldCallStep( - step_id=step_id, - turn_id=turn_id, - violation=None, - ), - ) - ) - ) + span.set_attribute("output", "no violations") async def _run( self, @@ -356,10 +365,15 @@ class ChatAgent(ShieldRunnerMixin): # TODO: find older context from the session and either replace it # or append with a sliding window. this is really a very simplistic implementation - with tracing.span("retrieve_rag_context"): + with tracing.span("retrieve_rag_context") as span: rag_context, bank_ids = await self._retrieve_context( session_id, input_messages, attachments ) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute("output", rag_context) + span.set_attribute("bank_ids", bank_ids) step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -416,7 +430,7 @@ class ChatAgent(ShieldRunnerMixin): content = "" stop_reason = None - with tracing.span("inference"): + with tracing.span("inference") as span: async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, @@ -436,7 +450,6 @@ class ChatAgent(ShieldRunnerMixin): if isinstance(delta, ToolCallDelta): if delta.parse_status == ToolCallParseStatus.success: tool_calls.append(delta.content) - if stream: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -466,6 +479,13 @@ class ChatAgent(ShieldRunnerMixin): if event.stop_reason is not None: stop_reason = event.stop_reason + span.set_attribute("stop_reason", stop_reason) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute( + "output", f"content: {content} tool_calls: {tool_calls}" + ) stop_reason = stop_reason or StopReason.out_of_tokens @@ -549,7 +569,13 @@ class ChatAgent(ShieldRunnerMixin): ) ) - with tracing.span("tool_execution"): + with tracing.span( + "tool_execution", + { + "tool_name": tool_call.tool_name, + "input": message.model_dump_json(), + }, + ) as span: result_messages = await execute_tool_call_maybe( self.tools_dict, [message], @@ -558,6 +584,7 @@ class ChatAgent(ShieldRunnerMixin): len(result_messages) == 1 ), "Currently not supporting multiple messages" result_message = result_messages[0] + span.set_attribute("output", result_message.model_dump_json()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 010610056..736e5d8b9 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,14 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +import base64 +import os from abc import ABC, abstractmethod from dataclasses import dataclass +from urllib.parse import urlparse from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -131,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_info = self.dataset_infos.get(dataset_id) + if dataset_info is None: + raise ValueError(f"Dataset with id {dataset_id} not found") + + dataset_impl = dataset_info.dataset_impl + dataset_impl.load() + + new_rows_df = pandas.DataFrame(rows) + new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) + dataset_impl.df = pandas.concat( + [dataset_impl.df, new_rows_df], ignore_index=True + ) + + url = str(dataset_info.dataset_def.url) + parsed_url = urlparse(url) + + if parsed_url.scheme == "file" or not parsed_url.scheme: + file_path = parsed_url.path + os.makedirs(os.path.dirname(file_path), exist_ok=True) + dataset_impl.df.to_csv(file_path, index=False) + elif parsed_url.scheme == "data": + # For data URLs, we need to update the base64-encoded content + if not parsed_url.path.startswith("text/csv;base64,"): + raise ValueError("Data URL must be a base64-encoded CSV") + + csv_buffer = dataset_impl.df.to_csv(index=False) + base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode( + "utf-8" + ) + dataset_info.dataset_def.url = URL( + uri=f"data:text/csv;base64,{base64_content}" + ) + else: + raise ValueError( + f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing." + ) diff --git a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py b/llama_stack/providers/inline/meta_reference/telemetry/__init__.py deleted file mode 100644 index 4a0c2f6ee..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import ConsoleConfig - - -async def get_provider_impl(config: ConsoleConfig, _deps): - from .console import ConsoleTelemetryImpl - - impl = ConsoleTelemetryImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/meta_reference/telemetry/config.py b/llama_stack/providers/inline/meta_reference/telemetry/config.py deleted file mode 100644 index a1db1d4d8..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/config.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - - -class LogFormat(Enum): - TEXT = "text" - JSON = "json" - - -@json_schema_type -class ConsoleConfig(BaseModel): - log_format: LogFormat = LogFormat.TEXT diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py index d8ef49481..838aaa4e1 100644 --- a/llama_stack/providers/inline/meta_reference/telemetry/console.py +++ b/llama_stack/providers/inline/meta_reference/telemetry/console.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import json -from typing import Optional +from typing import List, Optional from .config import LogFormat @@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry): if formatted: print(formatted) - async def get_trace(self, trace_id: str) -> Trace: - raise NotImplementedError() + async def query_traces( + self, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + raise NotImplementedError("Console telemetry does not support trace querying") + + async def get_spans( + self, + span_id: str, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> SpanWithChildren: + raise NotImplementedError("Console telemetry does not support span querying") COLORS = { diff --git a/llama_stack/providers/remote/telemetry/__init__.py b/llama_stack/providers/inline/telemetry/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/__init__.py rename to llama_stack/providers/inline/telemetry/__init__.py diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py new file mode 100644 index 000000000..6213d5536 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from .config import TelemetryConfig, TelemetrySink +from .telemetry import TelemetryAdapter + +__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"] + + +async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]): + impl = TelemetryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py new file mode 100644 index 000000000..0230d24d2 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List + +from pydantic import BaseModel, Field + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + + +class TelemetrySink(str, Enum): + JAEGER = "jaeger" + SQLITE = "sqlite" + CONSOLE = "console" + + +class TelemetryConfig(BaseModel): + otel_endpoint: str = Field( + default="http://localhost:4318/v1/traces", + description="The OpenTelemetry collector endpoint URL", + ) + service_name: str = Field( + default="llama-stack", + description="The service name to use for telemetry", + ) + sinks: List[TelemetrySink] = Field( + default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE], + description="List of telemetry sinks to enable (possible values: jaeger, sqlite, console)", + ) + sqlite_db_path: str = Field( + default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(), + description="The path to the SQLite database to use for storing traces", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", + "sinks": "${env.TELEMETRY_SINKS:['console', 'sqlite']}", + "sqlite_db_path": "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}", + } diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py new file mode 100644 index 000000000..8d6f779e6 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanProcessor + +# Colors for console output +COLORS = { + "reset": "\033[0m", + "bold": "\033[1m", + "dim": "\033[2m", + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", +} + + +class ConsoleSpanProcessor(SpanProcessor): + """A SpanProcessor that prints spans to the console with color formatting.""" + + def on_start(self, span: ReadableSpan, parent_context=None) -> None: + """Called when a span starts.""" + timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + + print( + f"{COLORS['dim']}{timestamp}{COLORS['reset']} " + f"{COLORS['magenta']}[START]{COLORS['reset']} " + f"{COLORS['cyan']}{span.name}{COLORS['reset']}" + ) + + def on_end(self, span: ReadableSpan) -> None: + """Called when a span ends.""" + timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + + # Build the span context string + span_context = ( + f"{COLORS['dim']}{timestamp}{COLORS['reset']} " + f"{COLORS['magenta']}[END]{COLORS['reset']} " + f"{COLORS['cyan']}{span.name}{COLORS['reset']} " + ) + + # Add status if not OK + if span.status.status_code != 0: # UNSET or ERROR + status_color = ( + COLORS["red"] if span.status.status_code == 2 else COLORS["yellow"] + ) + span_context += ( + f" {status_color}[{span.status.status_code}]{COLORS['reset']}" + ) + + # Add duration + duration_ms = (span.end_time - span.start_time) / 1e6 + span_context += f" {COLORS['dim']}({duration_ms:.2f}ms){COLORS['reset']}" + + # Print the main span line + print(span_context) + + # Print attributes indented + if span.attributes: + for key, value in span.attributes.items(): + print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") + + # Print events indented + for event in span.events: + event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + print( + f" {COLORS['dim']}{event_time}{COLORS['reset']} " + f"{COLORS['cyan']}[EVENT]{COLORS['reset']} {event.name}" + ) + if event.attributes: + for key, value in event.attributes.items(): + print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") + + def shutdown(self) -> None: + """Shutdown the processor.""" + pass + + def force_flush(self, timeout_millis: float = None) -> bool: + """Force flush any pending spans.""" + return True diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py new file mode 100644 index 000000000..553dd5000 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import os +import sqlite3 +import threading +from datetime import datetime, timedelta +from typing import Dict + +from opentelemetry.sdk.trace import SpanProcessor +from opentelemetry.trace import Span + + +class SQLiteSpanProcessor(SpanProcessor): + def __init__(self, conn_string, ttl_days=30): + """Initialize the SQLite span processor with a connection string.""" + self.conn_string = conn_string + self.ttl_days = ttl_days + self.cleanup_task = None + self._thread_local = threading.local() + self._connections: Dict[int, sqlite3.Connection] = {} + self._lock = threading.Lock() + self.setup_database() + + def _get_connection(self) -> sqlite3.Connection: + """Get a thread-specific database connection.""" + thread_id = threading.get_ident() + with self._lock: + if thread_id not in self._connections: + conn = sqlite3.connect(self.conn_string) + self._connections[thread_id] = conn + return self._connections[thread_id] + + def setup_database(self): + """Create the necessary tables if they don't exist.""" + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(self.conn_string), exist_ok=True) + + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS traces ( + trace_id TEXT PRIMARY KEY, + service_name TEXT, + root_span_id TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS spans ( + span_id TEXT PRIMARY KEY, + trace_id TEXT REFERENCES traces(trace_id), + parent_span_id TEXT, + name TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + attributes TEXT, + status TEXT, + kind TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS span_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + span_id TEXT REFERENCES spans(span_id), + name TEXT, + timestamp TIMESTAMP, + attributes TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_traces_created_at + ON traces(created_at) + """ + ) + + conn.commit() + cursor.close() + + # Start periodic cleanup in a separate thread + self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True) + self.cleanup_task.start() + + def _cleanup_old_data(self): + """Delete records older than TTL.""" + try: + conn = self._get_connection() + cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat() + cursor = conn.cursor() + + # Delete old span events + cursor.execute( + """ + DELETE FROM span_events + WHERE span_id IN ( + SELECT span_id FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + ) + """, + (cutoff_date,), + ) + + # Delete old spans + cursor.execute( + """ + DELETE FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + """, + (cutoff_date,), + ) + + # Delete old traces + cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,)) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error during cleanup: {e}") + + def _periodic_cleanup(self): + """Run cleanup periodically.""" + import time + + while True: + time.sleep(3600) # Sleep for 1 hour + self._cleanup_old_data() + + def on_start(self, span: Span, parent_context=None): + """Called when a span starts.""" + pass + + def on_end(self, span: Span): + """Called when a span ends. Export the span data to SQLite.""" + try: + conn = self._get_connection() + cursor = conn.cursor() + + trace_id = format(span.get_span_context().trace_id, "032x") + span_id = format(span.get_span_context().span_id, "016x") + service_name = span.resource.attributes.get("service.name", "unknown") + + parent_span_id = None + parent_context = span.parent + if parent_context: + parent_span_id = format(parent_context.span_id, "016x") + + # Insert into traces + cursor.execute( + """ + INSERT INTO traces ( + trace_id, service_name, root_span_id, start_time, end_time + ) VALUES (?, ?, ?, ?, ?) + ON CONFLICT(trace_id) DO UPDATE SET + root_span_id = COALESCE(root_span_id, excluded.root_span_id), + start_time = MIN(excluded.start_time, start_time), + end_time = MAX(excluded.end_time, end_time) + """, + ( + trace_id, + service_name, + (span_id if not parent_span_id else None), + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + ), + ) + + # Insert into spans + cursor.execute( + """ + INSERT INTO spans ( + span_id, trace_id, parent_span_id, name, + start_time, end_time, attributes, status, + kind + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + span_id, + trace_id, + parent_span_id, + span.name, + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + json.dumps(dict(span.attributes)), + span.status.status_code.name, + span.kind.name, + ), + ) + + for event in span.events: + cursor.execute( + """ + INSERT INTO span_events ( + span_id, name, timestamp, attributes + ) VALUES (?, ?, ?, ?) + """, + ( + span_id, + event.name, + datetime.fromtimestamp(event.timestamp / 1e9).isoformat(), + json.dumps(dict(event.attributes)), + ), + ) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error exporting span to SQLite: {e}") + + def shutdown(self): + """Cleanup any resources.""" + with self._lock: + for conn in self._connections.values(): + if conn: + conn.close() + self._connections.clear() + + def force_flush(self, timeout_millis=30000): + """Force export of spans.""" + pass diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py new file mode 100644 index 000000000..6540a667f --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import threading +from typing import List, Optional + +from opentelemetry import metrics, trace +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.semconv.resource import ResourceAttributes + +from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( + ConsoleSpanProcessor, +) + +from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( + SQLiteSpanProcessor, +) +from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore + +from llama_stack.apis.telemetry import * # noqa: F403 + +from .config import TelemetryConfig, TelemetrySink + +_GLOBAL_STORAGE = { + "active_spans": {}, + "counters": {}, + "gauges": {}, + "up_down_counters": {}, +} +_global_lock = threading.Lock() + + +def string_to_trace_id(s: str) -> int: + # Convert the string to bytes and then to an integer + return int.from_bytes(s.encode(), byteorder="big", signed=False) + + +def string_to_span_id(s: str) -> int: + # Use only the first 8 bytes (64 bits) for span ID + return int.from_bytes(s.encode()[:8], byteorder="big", signed=False) + + +def is_tracing_enabled(tracer): + with tracer.start_as_current_span("check_tracing") as span: + return span.is_recording() + + +class TelemetryAdapter(Telemetry): + def __init__(self, config: TelemetryConfig) -> None: + self.config = config + + resource = Resource.create( + { + ResourceAttributes.SERVICE_NAME: self.config.service_name, + } + ) + + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + if TelemetrySink.JAEGER in self.config.sinks: + otlp_exporter = OTLPSpanExporter( + endpoint=self.config.otel_endpoint, + ) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider = MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) + self._lock = _global_lock + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + trace.get_tracer_provider().force_flush() + trace.get_tracer_provider().shutdown() + metrics.get_meter_provider().shutdown() + + async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: + if isinstance(event, UnstructuredLogEvent): + self._log_unstructured(event, ttl_seconds) + elif isinstance(event, MetricEvent): + self._log_metric(event) + elif isinstance(event, StructuredLogEvent): + self._log_structured(event, ttl_seconds) + else: + raise ValueError(f"Unknown event type: {event}") + + def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: + with self._lock: + # Use global storage instead of instance storage + span_id = string_to_span_id(event.span_id) + span = _GLOBAL_STORAGE["active_spans"].get(span_id) + + if span: + timestamp_ns = int(event.timestamp.timestamp() * 1e9) + span.add_event( + name=event.type, + attributes={ + "message": event.message, + "severity": event.severity.value, + "__ttl__": ttl_seconds, + **event.attributes, + }, + timestamp=timestamp_ns, + ) + else: + print( + f"Warning: No active span found for span_id {span_id}. Dropping event: {event}" + ) + + def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter: + if name not in _GLOBAL_STORAGE["counters"]: + _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter( + name=name, + unit=unit, + description=f"Counter for {name}", + ) + return _GLOBAL_STORAGE["counters"][name] + + def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: + if name not in _GLOBAL_STORAGE["gauges"]: + _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge( + name=name, + unit=unit, + description=f"Gauge for {name}", + ) + return _GLOBAL_STORAGE["gauges"][name] + + def _log_metric(self, event: MetricEvent) -> None: + if isinstance(event.value, int): + counter = self._get_or_create_counter(event.metric, event.unit) + counter.add(event.value, attributes=event.attributes) + elif isinstance(event.value, float): + up_down_counter = self._get_or_create_up_down_counter( + event.metric, event.unit + ) + up_down_counter.add(event.value, attributes=event.attributes) + + def _get_or_create_up_down_counter( + self, name: str, unit: str + ) -> metrics.UpDownCounter: + if name not in _GLOBAL_STORAGE["up_down_counters"]: + _GLOBAL_STORAGE["up_down_counters"][name] = ( + self.meter.create_up_down_counter( + name=name, + unit=unit, + description=f"UpDownCounter for {name}", + ) + ) + return _GLOBAL_STORAGE["up_down_counters"][name] + + def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: + with self._lock: + span_id = string_to_span_id(event.span_id) + trace_id = string_to_trace_id(event.trace_id) + tracer = trace.get_tracer(__name__) + if event.attributes is None: + event.attributes = {} + event.attributes["__ttl__"] = ttl_seconds + + if isinstance(event.payload, SpanStartPayload): + # Check if span already exists to prevent duplicates + if span_id in _GLOBAL_STORAGE["active_spans"]: + return + + parent_span = None + if event.payload.parent_span_id: + parent_span_id = string_to_span_id(event.payload.parent_span_id) + parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) + + context = trace.Context(trace_id=trace_id) + if parent_span: + context = trace.set_span_in_context(parent_span, context) + + span = tracer.start_span( + name=event.payload.name, + context=context, + attributes=event.attributes or {}, + ) + _GLOBAL_STORAGE["active_spans"][span_id] = span + + elif isinstance(event.payload, SpanEndPayload): + span = _GLOBAL_STORAGE["active_spans"].get(span_id) + if span: + if event.attributes: + span.set_attributes(event.attributes) + + status = ( + trace.Status(status_code=trace.StatusCode.OK) + if event.payload.status == SpanStatus.OK + else trace.Status(status_code=trace.StatusCode.ERROR) + ) + span.set_status(status) + span.end() + _GLOBAL_STORAGE["active_spans"].pop(span_id, None) + else: + raise ValueError(f"Unknown structured log event: {event}") + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + return await self.trace_store.query_traces( + attribute_filters=attribute_filters, + limit=limit, + offset=offset, + order_by=order_by, + ) + + async def get_span_tree( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + return await self.trace_store.get_materialized_span( + span_id=span_id, + attributes_to_return=attributes_to_return, + max_depth=max_depth, + ) diff --git a/llama_stack/providers/remote/telemetry/sample/__init__.py b/llama_stack/providers/inline/telemetry/sample/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/__init__.py rename to llama_stack/providers/inline/telemetry/sample/__init__.py diff --git a/llama_stack/providers/remote/telemetry/sample/config.py b/llama_stack/providers/inline/telemetry/sample/config.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/config.py rename to llama_stack/providers/inline/telemetry/sample/config.py diff --git a/llama_stack/providers/remote/telemetry/sample/sample.py b/llama_stack/providers/inline/telemetry/sample/sample.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/sample.py rename to llama_stack/providers/inline/telemetry/sample/sample.py diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index ac537e076..a53ad5b94 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -14,9 +14,12 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.telemetry, provider_type="inline::meta-reference", - pip_packages=[], - module="llama_stack.providers.inline.meta_reference.telemetry", - config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig", + pip_packages=[ + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-http", + ], + module="llama_stack.providers.inline.telemetry.meta_reference", + config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig", ), remote_provider_spec( api=Api.telemetry, @@ -27,18 +30,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig", ), ), - remote_provider_spec( - api=Api.telemetry, - adapter=AdapterSpec( - adapter_type="opentelemetry-jaeger", - pip_packages=[ - "opentelemetry-api", - "opentelemetry-sdk", - "opentelemetry-exporter-jaeger", - "opentelemetry-semantic-conventions", - ], - module="llama_stack.providers.remote.telemetry.opentelemetry", - config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", - ), - ), ] diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index cdd5d9cd3..db52270a7 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional from llama_stack.apis.datasetio import * # noqa: F403 @@ -100,3 +100,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_def = self.dataset_infos[dataset_id] + loaded_dataset = load_hf_dataset(dataset_def) + + # Convert rows to HF Dataset format + new_dataset = hf_datasets.Dataset.from_list(rows) + + # Concatenate the new rows with existing dataset + updated_dataset = hf_datasets.concatenate_datasets( + [loaded_dataset, new_dataset] + ) + + if dataset_def.metadata.get("path", None): + updated_dataset.push_to_hub(dataset_def.metadata["path"]) + else: + raise NotImplementedError( + "Uploading to URL-based datasets is not supported yet" + ) diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py deleted file mode 100644 index 0842afe2d..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import OpenTelemetryConfig - - -async def get_adapter_impl(config: OpenTelemetryConfig, _deps): - from .opentelemetry import OpenTelemetryAdapter - - impl = OpenTelemetryAdapter(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py deleted file mode 100644 index 5e9dff1a1..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from pydantic import BaseModel, Field - - -class OpenTelemetryConfig(BaseModel): - otel_endpoint: str = Field( - default="http://localhost:4318/v1/traces", - description="The OpenTelemetry collector endpoint URL", - ) - service_name: str = Field( - default="llama-stack", - description="The service name to use for telemetry", - ) - - @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: - return { - "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}", - "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", - } diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py index c9830fd9d..04eb71ce0 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py @@ -5,6 +5,16 @@ # the root directory of this source tree. import threading +from typing import List, Optional + +from llama_stack.distribution.datatypes import Api +from llama_stack.providers.remote.telemetry.opentelemetry.console_span_processor import ( + ConsoleSpanProcessor, +) +from llama_stack.providers.remote.telemetry.opentelemetry.sqlite_span_processor import ( + SQLiteSpanProcessor, +) +from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -19,7 +29,7 @@ from opentelemetry.semconv.resource import ResourceAttributes from llama_stack.apis.telemetry import * # noqa: F403 -from .config import OpenTelemetryConfig +from .config import OpenTelemetryConfig, TelemetrySink _GLOBAL_STORAGE = { "active_spans": {}, @@ -46,8 +56,9 @@ def is_tracing_enabled(tracer): class OpenTelemetryAdapter(Telemetry): - def __init__(self, config: OpenTelemetryConfig): + def __init__(self, config: OpenTelemetryConfig, deps) -> None: self.config = config + self.datasetio = deps[Api.datasetio] resource = Resource.create( { @@ -57,22 +68,29 @@ class OpenTelemetryAdapter(Telemetry): provider = TracerProvider(resource=resource) trace.set_tracer_provider(provider) - otlp_exporter = OTLPSpanExporter( - endpoint=self.config.otel_endpoint, - ) - span_processor = BatchSpanProcessor(otlp_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) - # Set up metrics - metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter( + if TelemetrySink.JAEGER in self.config.sinks: + otlp_exporter = OTLPSpanExporter( endpoint=self.config.otel_endpoint, ) - ) - metric_provider = MeterProvider( - resource=resource, metric_readers=[metric_reader] - ) - metrics.set_meter_provider(metric_provider) - self.meter = metrics.get_meter(__name__) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider = MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) self._lock = _global_lock async def initialize(self) -> None: @@ -83,15 +101,17 @@ class OpenTelemetryAdapter(Telemetry): trace.get_tracer_provider().shutdown() metrics.get_meter_provider().shutdown() - async def log_event(self, event: Event) -> None: + async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: if isinstance(event, UnstructuredLogEvent): - self._log_unstructured(event) + self._log_unstructured(event, ttl_seconds) elif isinstance(event, MetricEvent): self._log_metric(event) elif isinstance(event, StructuredLogEvent): - self._log_structured(event) + self._log_structured(event, ttl_seconds) + else: + raise ValueError(f"Unknown event type: {event}") - def _log_unstructured(self, event: UnstructuredLogEvent) -> None: + def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: # Use global storage instead of instance storage span_id = string_to_span_id(event.span_id) @@ -104,6 +124,7 @@ class OpenTelemetryAdapter(Telemetry): attributes={ "message": event.message, "severity": event.severity.value, + "__ttl__": ttl_seconds, **event.attributes, }, timestamp=timestamp_ns, @@ -154,11 +175,14 @@ class OpenTelemetryAdapter(Telemetry): ) return _GLOBAL_STORAGE["up_down_counters"][name] - def _log_structured(self, event: StructuredLogEvent) -> None: + def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: span_id = string_to_span_id(event.span_id) trace_id = string_to_trace_id(event.trace_id) tracer = trace.get_tracer(__name__) + if event.attributes is None: + event.attributes = {} + event.attributes["__ttl__"] = ttl_seconds if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates @@ -170,7 +194,6 @@ class OpenTelemetryAdapter(Telemetry): parent_span_id = string_to_span_id(event.payload.parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - # Create a new trace context with the trace_id context = trace.Context(trace_id=trace_id) if parent_span: context = trace.set_span_in_context(parent_span, context) @@ -179,14 +202,9 @@ class OpenTelemetryAdapter(Telemetry): name=event.payload.name, context=context, attributes=event.attributes or {}, - start_time=int(event.timestamp.timestamp() * 1e9), ) _GLOBAL_STORAGE["active_spans"][span_id] = span - # Set as current span using context manager - with trace.use_span(span, end_on_exit=False): - pass # Let the span continue beyond this block - elif isinstance(event.payload, SpanEndPayload): span = _GLOBAL_STORAGE["active_spans"].get(span_id) if span: @@ -199,10 +217,43 @@ class OpenTelemetryAdapter(Telemetry): else trace.Status(status_code=trace.StatusCode.ERROR) ) span.set_status(status) - span.end(end_time=int(event.timestamp.timestamp() * 1e9)) - - # Remove from active spans + span.end() _GLOBAL_STORAGE["active_spans"].pop(span_id, None) + else: + raise ValueError(f"Unknown structured log event: {event}") - async def get_trace(self, trace_id: str) -> Trace: - raise NotImplementedError("Trace retrieval not implemented yet") + async def query_traces( + self, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + return await self.trace_store.query_traces( + attribute_conditions=attribute_conditions, + attribute_keys_to_return=attribute_keys_to_return, + limit=limit, + offset=offset, + order_by=order_by, + ) + + async def get_spans( + self, + span_id: str, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> SpanWithChildren: + return await self.trace_store.get_spans( + span_id=span_id, + attribute_conditions=attribute_conditions, + attribute_keys_to_return=attribute_keys_to_return, + max_depth=max_depth, + limit=limit, + offset=offset, + order_by=order_by, + ) diff --git a/llama_stack/providers/utils/telemetry/sqlite.py b/llama_stack/providers/utils/telemetry/sqlite.py new file mode 100644 index 000000000..e7161fffa --- /dev/null +++ b/llama_stack/providers/utils/telemetry/sqlite.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import datetime +from typing import List, Optional + +import aiosqlite + +from llama_stack.apis.telemetry import ( + QueryCondition, + SpanWithChildren, + Trace, + TraceStore, +) + + +class SQLiteTraceStore(TraceStore): + def __init__(self, conn_string: str): + self.conn_string = conn_string + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + attributes_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + print(attribute_filters, attributes_to_return, limit, offset, order_by) + + def build_attribute_select() -> str: + if not attributes_to_return: + return "" + return "".join( + f", json_extract(s.attributes, '$.{key}') as attr_{key}" + for key in attributes_to_return + ) + + def build_where_clause() -> tuple[str, list]: + if not attribute_filters: + return "", [] + + conditions = [ + f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?" + for condition in attribute_filters + ] + params = [condition.value for condition in attribute_filters] + where_clause = " WHERE " + " AND ".join(conditions) + return where_clause, params + + def build_order_clause() -> str: + if not order_by: + return "" + + order_clauses = [] + for field in order_by: + desc = field.startswith("-") + clean_field = field[1:] if desc else field + order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}") + return " ORDER BY " + ", ".join(order_clauses) + + # Build the main query + base_query = """ + WITH matching_traces AS ( + SELECT DISTINCT t.trace_id + FROM traces t + JOIN spans s ON t.trace_id = s.trace_id + {where_clause} + ), + filtered_traces AS ( + SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time + {attribute_select} + FROM matching_traces mt + JOIN traces t ON mt.trace_id = t.trace_id + LEFT JOIN spans s ON t.trace_id = s.trace_id + {order_clause} + ) + SELECT DISTINCT trace_id, root_span_id, start_time, end_time + FROM filtered_traces + LIMIT {limit} OFFSET {offset} + """ + + where_clause, params = build_where_clause() + query = base_query.format( + attribute_select=build_attribute_select(), + where_clause=where_clause, + order_clause=build_order_clause(), + limit=limit, + offset=offset, + ) + + # Execute query and return results + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [ + Trace( + trace_id=row["trace_id"], + root_span_id=row["root_span_id"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + ) + for row in rows + ] + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + # Build the attributes selection + attributes_select = "s.attributes" + if attributes_to_return: + json_object = ", ".join( + f"'{key}', json_extract(s.attributes, '$.{key}')" + for key in attributes_to_return + ) + attributes_select = f"json_object({json_object})" + + # SQLite CTE query with filtered attributes + query = f""" + WITH RECURSIVE span_tree AS ( + SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes + FROM spans s + WHERE s.span_id = ? + + UNION ALL + + SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes + FROM spans s + JOIN span_tree st ON s.parent_span_id = st.span_id + WHERE (? IS NULL OR st.depth < ?) + ) + SELECT * + FROM span_tree + ORDER BY depth, start_time + """ + + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor: + rows = await cursor.fetchall() + + if not rows: + raise ValueError(f"Span {span_id} not found") + + # Build span tree + spans_by_id = {} + root_span = None + + for row in rows: + span = SpanWithChildren( + span_id=row["span_id"], + trace_id=row["trace_id"], + parent_span_id=row["parent_span_id"], + name=row["name"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + attributes=json.loads(row["filtered_attributes"]), + status=row["status"].lower(), + children=[], + ) + + spans_by_id[span.span_id] = span + + if span.span_id == span_id: + root_span = span + elif span.parent_span_id in spans_by_id: + spans_by_id[span.parent_span_id].children.append(span) + + return root_span diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py new file mode 100644 index 000000000..ed1343e0b --- /dev/null +++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import datetime +from typing import List, Optional, Protocol + +import aiosqlite + +from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace + + +class TraceStore(Protocol): + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: ... + + +class SQLiteTraceStore(TraceStore): + def __init__(self, conn_string: str): + self.conn_string = conn_string + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + + def build_where_clause() -> tuple[str, list]: + if not attribute_filters: + return "", [] + + ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"} + + conditions = [ + f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?" + for condition in attribute_filters + ] + params = [condition.value for condition in attribute_filters] + where_clause = " WHERE " + " AND ".join(conditions) + return where_clause, params + + def build_order_clause() -> str: + if not order_by: + return "" + + order_clauses = [] + for field in order_by: + desc = field.startswith("-") + clean_field = field[1:] if desc else field + order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}") + return " ORDER BY " + ", ".join(order_clauses) + + # Build the main query + base_query = """ + WITH matching_traces AS ( + SELECT DISTINCT t.trace_id + FROM traces t + JOIN spans s ON t.trace_id = s.trace_id + {where_clause} + ), + filtered_traces AS ( + SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time + FROM matching_traces mt + JOIN traces t ON mt.trace_id = t.trace_id + LEFT JOIN spans s ON t.trace_id = s.trace_id + {order_clause} + ) + SELECT DISTINCT trace_id, root_span_id, start_time, end_time + FROM filtered_traces + LIMIT {limit} OFFSET {offset} + """ + + where_clause, params = build_where_clause() + query = base_query.format( + where_clause=where_clause, + order_clause=build_order_clause(), + limit=limit, + offset=offset, + ) + + # Execute query and return results + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [ + Trace( + trace_id=row["trace_id"], + root_span_id=row["root_span_id"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + ) + for row in rows + ] + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + # Build the attributes selection + attributes_select = "s.attributes" + if attributes_to_return: + json_object = ", ".join( + f"'{key}', json_extract(s.attributes, '$.{key}')" + for key in attributes_to_return + ) + attributes_select = f"json_object({json_object})" + + # SQLite CTE query with filtered attributes + query = f""" + WITH RECURSIVE span_tree AS ( + SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes + FROM spans s + WHERE s.span_id = ? + + UNION ALL + + SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes + FROM spans s + JOIN span_tree st ON s.parent_span_id = st.span_id + WHERE (? IS NULL OR st.depth < ?) + ) + SELECT * + FROM span_tree + ORDER BY depth, start_time + """ + + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor: + rows = await cursor.fetchall() + + if not rows: + raise ValueError(f"Span {span_id} not found") + + # Build span tree + spans_by_id = {} + root_span = None + + for row in rows: + span = SpanWithChildren( + span_id=row["span_id"], + trace_id=row["trace_id"], + parent_span_id=row["parent_span_id"], + name=row["name"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + attributes=json.loads(row["filtered_attributes"]), + status=row["status"].lower(), + children=[], + ) + + spans_by_id[span.span_id] = span + + if span.span_id == span_id: + root_span = span + elif span.parent_span_id in spans_by_id: + spans_by_id[span.parent_span_id].children.append(span) + + return root_span diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index b53dc0df9..54558afdc 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -69,7 +69,7 @@ class TraceContext: self.logger = logger self.trace_id = trace_id - def push_span(self, name: str, attributes: Dict[str, Any] = None): + def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span: current_span = self.get_current_span() span = Span( span_id=generate_short_uuid(), @@ -94,6 +94,7 @@ class TraceContext: ) self.spans.append(span) + return span def pop_span(self, status: SpanStatus = SpanStatus.OK): span = self.spans.pop() @@ -203,12 +204,13 @@ class SpanContextManager: def __init__(self, name: str, attributes: Dict[str, Any] = None): self.name = name self.attributes = attributes + self.span = None def __enter__(self): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT if context: - context.push_span(self.name, self.attributes) + self.span = context.push_span(self.name, self.attributes) return self def __exit__(self, exc_type, exc_value, traceback): @@ -217,11 +219,24 @@ class SpanContextManager: if context: context.pop_span() + def set_attribute(self, key: str, value: Any): + if self.span: + if self.span.attributes is None: + self.span.attributes = {} + self.span.attributes[key] = value + async def __aenter__(self): - return self.__enter__() + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + self.span = context.push_span(self.name, self.attributes) + return self async def __aexit__(self, exc_type, exc_value, traceback): - self.__exit__(exc_type, exc_value, traceback) + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + context.pop_span() def __call__(self, func: Callable): @wraps(func) @@ -246,3 +261,11 @@ class SpanContextManager: def span(name: str, attributes: Dict[str, Any] = None): return SpanContextManager(name, attributes) + + +def get_current_span() -> Optional[Span]: + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + return context.get_current_span() + return None