diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 25de35497..d2243c96f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -23,6 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated +from llama_stack.distribution.tracing import trace_protocol from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.common.deployment_types import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 @@ -418,6 +419,7 @@ class AgentStepResponse(BaseModel): @runtime_checkable +@trace_protocol class Agents(Protocol): @webmethod(route="/agents/create") async def create_agent( diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index c5052877a..22acc3211 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -37,3 +37,8 @@ class DatasetIO(Protocol): page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: ... + + @webmethod(route="/datasetio/append-rows", method="POST") + async def append_rows( + self, dataset_id: str, rows: List[Dict[str, Any]] + ) -> None: ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 5aadd97c7..85b29a147 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.distribution.tracing import trace_protocol + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 @@ -220,6 +222,7 @@ class ModelStore(Protocol): @runtime_checkable +@trace_protocol class Inference(Protocol): model_store: ModelStore diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 48b6e2241..b75df8a1a 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.distribution.tracing import trace_protocol @json_schema_type @@ -43,6 +44,7 @@ class MemoryBankStore(Protocol): @runtime_checkable +@trace_protocol class Memory(Protocol): memory_bank_store: MemoryBankStore diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 1b16af330..0b8b2563f 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -20,6 +20,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol @json_schema_type @@ -129,6 +130,7 @@ class MemoryBankInput(BaseModel): @runtime_checkable +@trace_protocol class MemoryBanks(Protocol): @webmethod(route="/memory-banks/list", method="GET") async def list_memory_banks(self) -> List[MemoryBank]: ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index cbd6265e2..2c0f1ee21 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol class CommonModelFields(BaseModel): @@ -43,6 +44,7 @@ class ModelInput(CommonModelFields): @runtime_checkable +@trace_protocol class Models(Protocol): @webmethod(route="/models/list", method="GET") async def list_models(self) -> List[Model]: ... diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 724f8dc96..41058f107 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -10,6 +10,8 @@ from typing import Any, Dict, List, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel +from llama_stack.distribution.tracing import trace_protocol + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 @@ -43,6 +45,7 @@ class ShieldStore(Protocol): @runtime_checkable +@trace_protocol class Safety(Protocol): shield_store: ShieldStore diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 5ee444f68..b28605727 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol class CommonShieldFields(BaseModel): @@ -38,6 +39,7 @@ class ShieldInput(CommonShieldFields): @runtime_checkable +@trace_protocol class Shields(Protocol): @webmethod(route="/shields/list", method="GET") async def list_shields(self) -> List[Shield]: ... diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 31f64733b..2ff783c46 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -6,12 +6,24 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +# Add this constant near the top of the file, after the imports +DEFAULT_TTL_DAYS = 7 + @json_schema_type class SpanStatus(Enum): @@ -29,6 +41,11 @@ class Span(BaseModel): end_time: Optional[datetime] = None attributes: Optional[Dict[str, Any]] = Field(default_factory=dict) + def set_attribute(self, key: str, value: Any): + if self.attributes is None: + self.attributes = {} + self.attributes[key] = value + @json_schema_type class Trace(BaseModel): @@ -123,10 +140,49 @@ Event = Annotated[ ] +@json_schema_type +class EvalTrace(BaseModel): + session_id: str + step: str + input: str + output: str + expected_output: str + + +@json_schema_type +class SpanWithChildren(Span): + children: List["SpanWithChildren"] = Field(default_factory=list) + status: Optional[SpanStatus] = None + + +@json_schema_type +class QueryCondition(BaseModel): + key: str + op: Literal["eq", "ne", "gt", "lt"] + value: Any + + @runtime_checkable class Telemetry(Protocol): - @webmethod(route="/telemetry/log-event") - async def log_event(self, event: Event) -> None: ... - @webmethod(route="/telemetry/get-trace", method="GET") - async def get_trace(self, trace_id: str) -> Trace: ... + @webmethod(route="/telemetry/log-event") + async def log_event( + self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400 + ) -> None: ... + + @webmethod(route="/telemetry/query-traces", method="POST") + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... + + @webmethod(route="/telemetry/get-span-tree", method="POST") + async def get_span_tree( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5a62b6d64..5b75a525b 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO): filter_condition=filter_condition, ) + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + return await self.routing_table.get_provider_impl(dataset_id).append_rows( + dataset_id=dataset_id, + rows=rows, + ) + class ScoringRouter(Scoring): def __init__( diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8116e2b39..4ae1854df 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -43,9 +43,9 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) -from llama_stack.providers.inline.meta_reference.telemetry.console import ( - ConsoleConfig, - ConsoleTelemetryImpl, +from llama_stack.providers.inline.telemetry.meta_reference import ( + TelemetryAdapter, + TelemetryConfig, ) from .endpoints import get_all_api_endpoints @@ -290,7 +290,7 @@ def main(): if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: - setup_logger(ConsoleTelemetryImpl(ConsoleConfig())) + setup_logger(TelemetryAdapter(TelemetryConfig())) all_endpoints = get_all_api_endpoints() diff --git a/llama_stack/distribution/tracing.py b/llama_stack/distribution/tracing.py new file mode 100644 index 000000000..ea663ec89 --- /dev/null +++ b/llama_stack/distribution/tracing.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import inspect +import json +from functools import wraps +from typing import Any, AsyncGenerator, Callable, Type, TypeVar + +from pydantic import BaseModel + +from llama_stack.providers.utils.telemetry import tracing + +T = TypeVar("T") + + +def serialize_value(value: Any) -> str: + """Helper function to serialize values to string representation.""" + try: + if isinstance(value, BaseModel): + return value.model_dump_json() + elif isinstance(value, list) and value and isinstance(value[0], BaseModel): + return json.dumps([item.model_dump_json() for item in value]) + elif hasattr(value, "to_dict"): + return json.dumps(value.to_dict()) + elif isinstance(value, (dict, list, int, float, str, bool)): + return json.dumps(value) + else: + return str(value) + except Exception: + return str(value) + + +def trace_protocol(cls: Type[T]) -> Type[T]: + """ + A class decorator that automatically traces all methods in a protocol/base class + and its inheriting classes. + """ + + def trace_method(method: Callable) -> Callable: + is_async = asyncio.iscoroutinefunction(method) + is_async_gen = inspect.isasyncgenfunction(method) + + def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: + class_name = self.__class__.__name__ + method_name = method.__name__ + + span_type = ( + "async_generator" if is_async_gen else "async" if is_async else "sync" + ) + span_attributes = { + "class": class_name, + "method": method_name, + "type": span_type, + "args": serialize_value(args), + } + + return class_name, method_name, span_attributes + + @wraps(method) + async def async_gen_wrapper( + self: Any, *args: Any, **kwargs: Any + ) -> AsyncGenerator: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + count = 0 + async for item in method(self, *args, **kwargs): + yield item + count += 1 + finally: + span.set_attribute("chunk_count", count) + + @wraps(method) + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + result = await method(self, *args, **kwargs) + span.set_attribute("output", serialize_value(result)) + return result + except Exception as e: + span.set_attribute("error", str(e)) + raise + + @wraps(method) + def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + result = method(self, *args, **kwargs) + span.set_attribute("output", serialize_value(result)) + return result + except Exception as e: + raise + + if is_async_gen: + return async_gen_wrapper + elif is_async: + return async_wrapper + else: + return sync_wrapper + + original_init_subclass = getattr(cls, "__init_subclass__", None) + + def __init_subclass__(cls_child, **kwargs): # noqa: N807 + if original_init_subclass: + original_init_subclass(**kwargs) + + for name, method in vars(cls_child).items(): + if inspect.isfunction(method) and not name.startswith("_"): + setattr(cls_child, name, trace_method(method)) # noqa: B010 + + cls.__init_subclass__ = classmethod(__init_subclass__) + + return cls diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8f800ad6f..7df5d3bd4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -144,87 +144,91 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) - @tracing.span("create_and_execute_turn") async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: - assert request.stream is True, "Non-streaming not supported" + with tracing.span("create_and_execute_turn") as span: + span.set_attribute("session_id", request.session_id) + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" - session_info = await self.storage.get_session_info(request.session_id) - if session_info is None: - raise ValueError(f"Session {request.session_id} not found") + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") - turns = await self.storage.get_session_turns(request.session_id) + turns = await self.storage.get_session_turns(request.session_id) - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) + for i, turn in enumerate(turns): + messages.extend(self.turn_to_messages(turn)) - messages.extend(request.messages) + messages.extend(request.messages) - turn_id = str(uuid.uuid4()) - start_time = datetime.now() - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnStartPayload( - turn_id=turn_id, + turn_id = str(uuid.uuid4()) + span.set_attribute("turn_id", turn_id) + start_time = datetime.now() + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnStartPayload( + turn_id=turn_id, + ) ) ) - ) - steps = [] - output_message = None - async for chunk in self.run( - session_id=request.session_id, - turn_id=turn_id, - input_messages=messages, - attachments=request.attachments or [], - sampling_params=self.agent_config.sampling_params, - stream=request.stream, - ): - if isinstance(chunk, CompletionMessage): - log.info( - f"{chunk.role.capitalize()}: {chunk.content}", - ) - output_message = chunk - continue - - assert isinstance( - chunk, AgentTurnResponseStreamChunk - ), f"Unexpected type {type(chunk)}" - event = chunk.event - if ( - event.payload.event_type - == AgentTurnResponseEventType.step_complete.value + steps = [] + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=turn_id, + input_messages=messages, + attachments=request.attachments or [], + sampling_params=self.agent_config.sampling_params, + stream=request.stream, ): - steps.append(event.payload.step_details) + if isinstance(chunk, CompletionMessage): + log.info( + f"{chunk.role.capitalize()}: {chunk.content}", + ) + output_message = chunk + continue - yield chunk + assert isinstance( + chunk, AgentTurnResponseStreamChunk + ), f"Unexpected type {type(chunk)}" + event = chunk.event + if ( + event.payload.event_type + == AgentTurnResponseEventType.step_complete.value + ): + steps.append(event.payload.step_details) - assert output_message is not None + yield chunk - turn = Turn( - turn_id=turn_id, - session_id=request.session_id, - input_messages=request.messages, - output_message=output_message, - started_at=start_time, - completed_at=datetime.now(), - steps=steps, - ) - await self.storage.add_turn_to_session(request.session_id, turn) + assert output_message is not None - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + turn = Turn( + turn_id=turn_id, + session_id=request.session_id, + input_messages=request.messages, + output_message=output_message, + started_at=start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) ) ) - ) - yield chunk + yield chunk async def run( self, @@ -273,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin): yield final_response - @tracing.span("run_shields") async def run_multiple_shields_wrapper( self, turn_id: str, @@ -281,23 +284,47 @@ class ChatAgent(ShieldRunnerMixin): shields: List[str], touchpoint: str, ) -> AsyncGenerator: - if len(shields) == 0: - return + with tracing.span("run_shields") as span: + span.set_attribute("turn_id", turn_id) + span.set_attribute("input", [m.model_dump_json() for m in messages]) + if len(shields) == 0: + span.set_attribute("output", "no shields") + return - step_id = str(uuid.uuid4()) - try: - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.shield_call.value, - step_id=step_id, - metadata=dict(touchpoint=touchpoint), + step_id = str(uuid.uuid4()) + try: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( + step_type=StepType.shield_call.value, + step_id=step_id, + metadata=dict(touchpoint=touchpoint), + ) ) ) - ) - await self.run_multiple_shields(messages, shields) + await self.run_multiple_shields(messages, shields) + + except SafetyException as e: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.shield_call.value, + step_details=ShieldCallStep( + step_id=step_id, + turn_id=turn_id, + violation=e.violation, + ), + ) + ) + ) + span.set_attribute("output", e.violation.model_dump_json()) + + yield CompletionMessage( + content=str(e), + stop_reason=StopReason.end_of_turn, + ) + yield False - except SafetyException as e: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( @@ -305,30 +332,12 @@ class ChatAgent(ShieldRunnerMixin): step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, - violation=e.violation, + violation=None, ), ) ) ) - - yield CompletionMessage( - content=str(e), - stop_reason=StopReason.end_of_turn, - ) - yield False - - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.shield_call.value, - step_details=ShieldCallStep( - step_id=step_id, - turn_id=turn_id, - violation=None, - ), - ) - ) - ) + span.set_attribute("output", "no violations") async def _run( self, @@ -356,10 +365,15 @@ class ChatAgent(ShieldRunnerMixin): # TODO: find older context from the session and either replace it # or append with a sliding window. this is really a very simplistic implementation - with tracing.span("retrieve_rag_context"): + with tracing.span("retrieve_rag_context") as span: rag_context, bank_ids = await self._retrieve_context( session_id, input_messages, attachments ) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute("output", rag_context) + span.set_attribute("bank_ids", bank_ids) step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -416,7 +430,7 @@ class ChatAgent(ShieldRunnerMixin): content = "" stop_reason = None - with tracing.span("inference"): + with tracing.span("inference") as span: async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, @@ -436,7 +450,6 @@ class ChatAgent(ShieldRunnerMixin): if isinstance(delta, ToolCallDelta): if delta.parse_status == ToolCallParseStatus.success: tool_calls.append(delta.content) - if stream: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -466,6 +479,13 @@ class ChatAgent(ShieldRunnerMixin): if event.stop_reason is not None: stop_reason = event.stop_reason + span.set_attribute("stop_reason", stop_reason) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute( + "output", f"content: {content} tool_calls: {tool_calls}" + ) stop_reason = stop_reason or StopReason.out_of_tokens @@ -549,7 +569,13 @@ class ChatAgent(ShieldRunnerMixin): ) ) - with tracing.span("tool_execution"): + with tracing.span( + "tool_execution", + { + "tool_name": tool_call.tool_name, + "input": message.model_dump_json(), + }, + ) as span: result_messages = await execute_tool_call_maybe( self.tools_dict, [message], @@ -558,6 +584,7 @@ class ChatAgent(ShieldRunnerMixin): len(result_messages) == 1 ), "Currently not supporting multiple messages" result_message = result_messages[0] + span.set_attribute("output", result_message.model_dump_json()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 010610056..736e5d8b9 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,14 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +import base64 +import os from abc import ABC, abstractmethod from dataclasses import dataclass +from urllib.parse import urlparse from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -131,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_info = self.dataset_infos.get(dataset_id) + if dataset_info is None: + raise ValueError(f"Dataset with id {dataset_id} not found") + + dataset_impl = dataset_info.dataset_impl + dataset_impl.load() + + new_rows_df = pandas.DataFrame(rows) + new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) + dataset_impl.df = pandas.concat( + [dataset_impl.df, new_rows_df], ignore_index=True + ) + + url = str(dataset_info.dataset_def.url) + parsed_url = urlparse(url) + + if parsed_url.scheme == "file" or not parsed_url.scheme: + file_path = parsed_url.path + os.makedirs(os.path.dirname(file_path), exist_ok=True) + dataset_impl.df.to_csv(file_path, index=False) + elif parsed_url.scheme == "data": + # For data URLs, we need to update the base64-encoded content + if not parsed_url.path.startswith("text/csv;base64,"): + raise ValueError("Data URL must be a base64-encoded CSV") + + csv_buffer = dataset_impl.df.to_csv(index=False) + base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode( + "utf-8" + ) + dataset_info.dataset_def.url = URL( + uri=f"data:text/csv;base64,{base64_content}" + ) + else: + raise ValueError( + f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing." + ) diff --git a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py b/llama_stack/providers/inline/meta_reference/telemetry/__init__.py deleted file mode 100644 index 4a0c2f6ee..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import ConsoleConfig - - -async def get_provider_impl(config: ConsoleConfig, _deps): - from .console import ConsoleTelemetryImpl - - impl = ConsoleTelemetryImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/meta_reference/telemetry/config.py b/llama_stack/providers/inline/meta_reference/telemetry/config.py deleted file mode 100644 index a1db1d4d8..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/config.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - - -class LogFormat(Enum): - TEXT = "text" - JSON = "json" - - -@json_schema_type -class ConsoleConfig(BaseModel): - log_format: LogFormat = LogFormat.TEXT diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py index d8ef49481..838aaa4e1 100644 --- a/llama_stack/providers/inline/meta_reference/telemetry/console.py +++ b/llama_stack/providers/inline/meta_reference/telemetry/console.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import json -from typing import Optional +from typing import List, Optional from .config import LogFormat @@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry): if formatted: print(formatted) - async def get_trace(self, trace_id: str) -> Trace: - raise NotImplementedError() + async def query_traces( + self, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + raise NotImplementedError("Console telemetry does not support trace querying") + + async def get_spans( + self, + span_id: str, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> SpanWithChildren: + raise NotImplementedError("Console telemetry does not support span querying") COLORS = { diff --git a/llama_stack/providers/remote/telemetry/__init__.py b/llama_stack/providers/inline/telemetry/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/__init__.py rename to llama_stack/providers/inline/telemetry/__init__.py diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py new file mode 100644 index 000000000..6213d5536 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from .config import TelemetryConfig, TelemetrySink +from .telemetry import TelemetryAdapter + +__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"] + + +async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]): + impl = TelemetryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py new file mode 100644 index 000000000..0230d24d2 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List + +from pydantic import BaseModel, Field + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + + +class TelemetrySink(str, Enum): + JAEGER = "jaeger" + SQLITE = "sqlite" + CONSOLE = "console" + + +class TelemetryConfig(BaseModel): + otel_endpoint: str = Field( + default="http://localhost:4318/v1/traces", + description="The OpenTelemetry collector endpoint URL", + ) + service_name: str = Field( + default="llama-stack", + description="The service name to use for telemetry", + ) + sinks: List[TelemetrySink] = Field( + default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE], + description="List of telemetry sinks to enable (possible values: jaeger, sqlite, console)", + ) + sqlite_db_path: str = Field( + default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(), + description="The path to the SQLite database to use for storing traces", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", + "sinks": "${env.TELEMETRY_SINKS:['console', 'sqlite']}", + "sqlite_db_path": "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}", + } diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py new file mode 100644 index 000000000..8d6f779e6 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanProcessor + +# Colors for console output +COLORS = { + "reset": "\033[0m", + "bold": "\033[1m", + "dim": "\033[2m", + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", +} + + +class ConsoleSpanProcessor(SpanProcessor): + """A SpanProcessor that prints spans to the console with color formatting.""" + + def on_start(self, span: ReadableSpan, parent_context=None) -> None: + """Called when a span starts.""" + timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + + print( + f"{COLORS['dim']}{timestamp}{COLORS['reset']} " + f"{COLORS['magenta']}[START]{COLORS['reset']} " + f"{COLORS['cyan']}{span.name}{COLORS['reset']}" + ) + + def on_end(self, span: ReadableSpan) -> None: + """Called when a span ends.""" + timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + + # Build the span context string + span_context = ( + f"{COLORS['dim']}{timestamp}{COLORS['reset']} " + f"{COLORS['magenta']}[END]{COLORS['reset']} " + f"{COLORS['cyan']}{span.name}{COLORS['reset']} " + ) + + # Add status if not OK + if span.status.status_code != 0: # UNSET or ERROR + status_color = ( + COLORS["red"] if span.status.status_code == 2 else COLORS["yellow"] + ) + span_context += ( + f" {status_color}[{span.status.status_code}]{COLORS['reset']}" + ) + + # Add duration + duration_ms = (span.end_time - span.start_time) / 1e6 + span_context += f" {COLORS['dim']}({duration_ms:.2f}ms){COLORS['reset']}" + + # Print the main span line + print(span_context) + + # Print attributes indented + if span.attributes: + for key, value in span.attributes.items(): + print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") + + # Print events indented + for event in span.events: + event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + print( + f" {COLORS['dim']}{event_time}{COLORS['reset']} " + f"{COLORS['cyan']}[EVENT]{COLORS['reset']} {event.name}" + ) + if event.attributes: + for key, value in event.attributes.items(): + print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") + + def shutdown(self) -> None: + """Shutdown the processor.""" + pass + + def force_flush(self, timeout_millis: float = None) -> bool: + """Force flush any pending spans.""" + return True diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py new file mode 100644 index 000000000..553dd5000 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import os +import sqlite3 +import threading +from datetime import datetime, timedelta +from typing import Dict + +from opentelemetry.sdk.trace import SpanProcessor +from opentelemetry.trace import Span + + +class SQLiteSpanProcessor(SpanProcessor): + def __init__(self, conn_string, ttl_days=30): + """Initialize the SQLite span processor with a connection string.""" + self.conn_string = conn_string + self.ttl_days = ttl_days + self.cleanup_task = None + self._thread_local = threading.local() + self._connections: Dict[int, sqlite3.Connection] = {} + self._lock = threading.Lock() + self.setup_database() + + def _get_connection(self) -> sqlite3.Connection: + """Get a thread-specific database connection.""" + thread_id = threading.get_ident() + with self._lock: + if thread_id not in self._connections: + conn = sqlite3.connect(self.conn_string) + self._connections[thread_id] = conn + return self._connections[thread_id] + + def setup_database(self): + """Create the necessary tables if they don't exist.""" + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(self.conn_string), exist_ok=True) + + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS traces ( + trace_id TEXT PRIMARY KEY, + service_name TEXT, + root_span_id TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS spans ( + span_id TEXT PRIMARY KEY, + trace_id TEXT REFERENCES traces(trace_id), + parent_span_id TEXT, + name TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + attributes TEXT, + status TEXT, + kind TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS span_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + span_id TEXT REFERENCES spans(span_id), + name TEXT, + timestamp TIMESTAMP, + attributes TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_traces_created_at + ON traces(created_at) + """ + ) + + conn.commit() + cursor.close() + + # Start periodic cleanup in a separate thread + self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True) + self.cleanup_task.start() + + def _cleanup_old_data(self): + """Delete records older than TTL.""" + try: + conn = self._get_connection() + cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat() + cursor = conn.cursor() + + # Delete old span events + cursor.execute( + """ + DELETE FROM span_events + WHERE span_id IN ( + SELECT span_id FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + ) + """, + (cutoff_date,), + ) + + # Delete old spans + cursor.execute( + """ + DELETE FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + """, + (cutoff_date,), + ) + + # Delete old traces + cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,)) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error during cleanup: {e}") + + def _periodic_cleanup(self): + """Run cleanup periodically.""" + import time + + while True: + time.sleep(3600) # Sleep for 1 hour + self._cleanup_old_data() + + def on_start(self, span: Span, parent_context=None): + """Called when a span starts.""" + pass + + def on_end(self, span: Span): + """Called when a span ends. Export the span data to SQLite.""" + try: + conn = self._get_connection() + cursor = conn.cursor() + + trace_id = format(span.get_span_context().trace_id, "032x") + span_id = format(span.get_span_context().span_id, "016x") + service_name = span.resource.attributes.get("service.name", "unknown") + + parent_span_id = None + parent_context = span.parent + if parent_context: + parent_span_id = format(parent_context.span_id, "016x") + + # Insert into traces + cursor.execute( + """ + INSERT INTO traces ( + trace_id, service_name, root_span_id, start_time, end_time + ) VALUES (?, ?, ?, ?, ?) + ON CONFLICT(trace_id) DO UPDATE SET + root_span_id = COALESCE(root_span_id, excluded.root_span_id), + start_time = MIN(excluded.start_time, start_time), + end_time = MAX(excluded.end_time, end_time) + """, + ( + trace_id, + service_name, + (span_id if not parent_span_id else None), + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + ), + ) + + # Insert into spans + cursor.execute( + """ + INSERT INTO spans ( + span_id, trace_id, parent_span_id, name, + start_time, end_time, attributes, status, + kind + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + span_id, + trace_id, + parent_span_id, + span.name, + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + json.dumps(dict(span.attributes)), + span.status.status_code.name, + span.kind.name, + ), + ) + + for event in span.events: + cursor.execute( + """ + INSERT INTO span_events ( + span_id, name, timestamp, attributes + ) VALUES (?, ?, ?, ?) + """, + ( + span_id, + event.name, + datetime.fromtimestamp(event.timestamp / 1e9).isoformat(), + json.dumps(dict(event.attributes)), + ), + ) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error exporting span to SQLite: {e}") + + def shutdown(self): + """Cleanup any resources.""" + with self._lock: + for conn in self._connections.values(): + if conn: + conn.close() + self._connections.clear() + + def force_flush(self, timeout_millis=30000): + """Force export of spans.""" + pass diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py new file mode 100644 index 000000000..6540a667f --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import threading +from typing import List, Optional + +from opentelemetry import metrics, trace +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.semconv.resource import ResourceAttributes + +from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( + ConsoleSpanProcessor, +) + +from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( + SQLiteSpanProcessor, +) +from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore + +from llama_stack.apis.telemetry import * # noqa: F403 + +from .config import TelemetryConfig, TelemetrySink + +_GLOBAL_STORAGE = { + "active_spans": {}, + "counters": {}, + "gauges": {}, + "up_down_counters": {}, +} +_global_lock = threading.Lock() + + +def string_to_trace_id(s: str) -> int: + # Convert the string to bytes and then to an integer + return int.from_bytes(s.encode(), byteorder="big", signed=False) + + +def string_to_span_id(s: str) -> int: + # Use only the first 8 bytes (64 bits) for span ID + return int.from_bytes(s.encode()[:8], byteorder="big", signed=False) + + +def is_tracing_enabled(tracer): + with tracer.start_as_current_span("check_tracing") as span: + return span.is_recording() + + +class TelemetryAdapter(Telemetry): + def __init__(self, config: TelemetryConfig) -> None: + self.config = config + + resource = Resource.create( + { + ResourceAttributes.SERVICE_NAME: self.config.service_name, + } + ) + + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + if TelemetrySink.JAEGER in self.config.sinks: + otlp_exporter = OTLPSpanExporter( + endpoint=self.config.otel_endpoint, + ) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider = MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) + self._lock = _global_lock + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + trace.get_tracer_provider().force_flush() + trace.get_tracer_provider().shutdown() + metrics.get_meter_provider().shutdown() + + async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: + if isinstance(event, UnstructuredLogEvent): + self._log_unstructured(event, ttl_seconds) + elif isinstance(event, MetricEvent): + self._log_metric(event) + elif isinstance(event, StructuredLogEvent): + self._log_structured(event, ttl_seconds) + else: + raise ValueError(f"Unknown event type: {event}") + + def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: + with self._lock: + # Use global storage instead of instance storage + span_id = string_to_span_id(event.span_id) + span = _GLOBAL_STORAGE["active_spans"].get(span_id) + + if span: + timestamp_ns = int(event.timestamp.timestamp() * 1e9) + span.add_event( + name=event.type, + attributes={ + "message": event.message, + "severity": event.severity.value, + "__ttl__": ttl_seconds, + **event.attributes, + }, + timestamp=timestamp_ns, + ) + else: + print( + f"Warning: No active span found for span_id {span_id}. Dropping event: {event}" + ) + + def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter: + if name not in _GLOBAL_STORAGE["counters"]: + _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter( + name=name, + unit=unit, + description=f"Counter for {name}", + ) + return _GLOBAL_STORAGE["counters"][name] + + def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: + if name not in _GLOBAL_STORAGE["gauges"]: + _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge( + name=name, + unit=unit, + description=f"Gauge for {name}", + ) + return _GLOBAL_STORAGE["gauges"][name] + + def _log_metric(self, event: MetricEvent) -> None: + if isinstance(event.value, int): + counter = self._get_or_create_counter(event.metric, event.unit) + counter.add(event.value, attributes=event.attributes) + elif isinstance(event.value, float): + up_down_counter = self._get_or_create_up_down_counter( + event.metric, event.unit + ) + up_down_counter.add(event.value, attributes=event.attributes) + + def _get_or_create_up_down_counter( + self, name: str, unit: str + ) -> metrics.UpDownCounter: + if name not in _GLOBAL_STORAGE["up_down_counters"]: + _GLOBAL_STORAGE["up_down_counters"][name] = ( + self.meter.create_up_down_counter( + name=name, + unit=unit, + description=f"UpDownCounter for {name}", + ) + ) + return _GLOBAL_STORAGE["up_down_counters"][name] + + def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: + with self._lock: + span_id = string_to_span_id(event.span_id) + trace_id = string_to_trace_id(event.trace_id) + tracer = trace.get_tracer(__name__) + if event.attributes is None: + event.attributes = {} + event.attributes["__ttl__"] = ttl_seconds + + if isinstance(event.payload, SpanStartPayload): + # Check if span already exists to prevent duplicates + if span_id in _GLOBAL_STORAGE["active_spans"]: + return + + parent_span = None + if event.payload.parent_span_id: + parent_span_id = string_to_span_id(event.payload.parent_span_id) + parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) + + context = trace.Context(trace_id=trace_id) + if parent_span: + context = trace.set_span_in_context(parent_span, context) + + span = tracer.start_span( + name=event.payload.name, + context=context, + attributes=event.attributes or {}, + ) + _GLOBAL_STORAGE["active_spans"][span_id] = span + + elif isinstance(event.payload, SpanEndPayload): + span = _GLOBAL_STORAGE["active_spans"].get(span_id) + if span: + if event.attributes: + span.set_attributes(event.attributes) + + status = ( + trace.Status(status_code=trace.StatusCode.OK) + if event.payload.status == SpanStatus.OK + else trace.Status(status_code=trace.StatusCode.ERROR) + ) + span.set_status(status) + span.end() + _GLOBAL_STORAGE["active_spans"].pop(span_id, None) + else: + raise ValueError(f"Unknown structured log event: {event}") + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + return await self.trace_store.query_traces( + attribute_filters=attribute_filters, + limit=limit, + offset=offset, + order_by=order_by, + ) + + async def get_span_tree( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + return await self.trace_store.get_materialized_span( + span_id=span_id, + attributes_to_return=attributes_to_return, + max_depth=max_depth, + ) diff --git a/llama_stack/providers/remote/telemetry/sample/__init__.py b/llama_stack/providers/inline/telemetry/sample/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/__init__.py rename to llama_stack/providers/inline/telemetry/sample/__init__.py diff --git a/llama_stack/providers/remote/telemetry/sample/config.py b/llama_stack/providers/inline/telemetry/sample/config.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/config.py rename to llama_stack/providers/inline/telemetry/sample/config.py diff --git a/llama_stack/providers/remote/telemetry/sample/sample.py b/llama_stack/providers/inline/telemetry/sample/sample.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/sample.py rename to llama_stack/providers/inline/telemetry/sample/sample.py diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index ac537e076..a53ad5b94 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -14,9 +14,12 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.telemetry, provider_type="inline::meta-reference", - pip_packages=[], - module="llama_stack.providers.inline.meta_reference.telemetry", - config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig", + pip_packages=[ + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-http", + ], + module="llama_stack.providers.inline.telemetry.meta_reference", + config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig", ), remote_provider_spec( api=Api.telemetry, @@ -27,18 +30,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig", ), ), - remote_provider_spec( - api=Api.telemetry, - adapter=AdapterSpec( - adapter_type="opentelemetry-jaeger", - pip_packages=[ - "opentelemetry-api", - "opentelemetry-sdk", - "opentelemetry-exporter-jaeger", - "opentelemetry-semantic-conventions", - ], - module="llama_stack.providers.remote.telemetry.opentelemetry", - config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", - ), - ), ] diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index cdd5d9cd3..db52270a7 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional from llama_stack.apis.datasetio import * # noqa: F403 @@ -100,3 +100,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_def = self.dataset_infos[dataset_id] + loaded_dataset = load_hf_dataset(dataset_def) + + # Convert rows to HF Dataset format + new_dataset = hf_datasets.Dataset.from_list(rows) + + # Concatenate the new rows with existing dataset + updated_dataset = hf_datasets.concatenate_datasets( + [loaded_dataset, new_dataset] + ) + + if dataset_def.metadata.get("path", None): + updated_dataset.push_to_hub(dataset_def.metadata["path"]) + else: + raise NotImplementedError( + "Uploading to URL-based datasets is not supported yet" + ) diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py deleted file mode 100644 index 0842afe2d..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import OpenTelemetryConfig - - -async def get_adapter_impl(config: OpenTelemetryConfig, _deps): - from .opentelemetry import OpenTelemetryAdapter - - impl = OpenTelemetryAdapter(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py deleted file mode 100644 index 5e9dff1a1..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from pydantic import BaseModel, Field - - -class OpenTelemetryConfig(BaseModel): - otel_endpoint: str = Field( - default="http://localhost:4318/v1/traces", - description="The OpenTelemetry collector endpoint URL", - ) - service_name: str = Field( - default="llama-stack", - description="The service name to use for telemetry", - ) - - @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: - return { - "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}", - "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", - } diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py index c9830fd9d..04eb71ce0 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py @@ -5,6 +5,16 @@ # the root directory of this source tree. import threading +from typing import List, Optional + +from llama_stack.distribution.datatypes import Api +from llama_stack.providers.remote.telemetry.opentelemetry.console_span_processor import ( + ConsoleSpanProcessor, +) +from llama_stack.providers.remote.telemetry.opentelemetry.sqlite_span_processor import ( + SQLiteSpanProcessor, +) +from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -19,7 +29,7 @@ from opentelemetry.semconv.resource import ResourceAttributes from llama_stack.apis.telemetry import * # noqa: F403 -from .config import OpenTelemetryConfig +from .config import OpenTelemetryConfig, TelemetrySink _GLOBAL_STORAGE = { "active_spans": {}, @@ -46,8 +56,9 @@ def is_tracing_enabled(tracer): class OpenTelemetryAdapter(Telemetry): - def __init__(self, config: OpenTelemetryConfig): + def __init__(self, config: OpenTelemetryConfig, deps) -> None: self.config = config + self.datasetio = deps[Api.datasetio] resource = Resource.create( { @@ -57,22 +68,29 @@ class OpenTelemetryAdapter(Telemetry): provider = TracerProvider(resource=resource) trace.set_tracer_provider(provider) - otlp_exporter = OTLPSpanExporter( - endpoint=self.config.otel_endpoint, - ) - span_processor = BatchSpanProcessor(otlp_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) - # Set up metrics - metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter( + if TelemetrySink.JAEGER in self.config.sinks: + otlp_exporter = OTLPSpanExporter( endpoint=self.config.otel_endpoint, ) - ) - metric_provider = MeterProvider( - resource=resource, metric_readers=[metric_reader] - ) - metrics.set_meter_provider(metric_provider) - self.meter = metrics.get_meter(__name__) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider = MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) self._lock = _global_lock async def initialize(self) -> None: @@ -83,15 +101,17 @@ class OpenTelemetryAdapter(Telemetry): trace.get_tracer_provider().shutdown() metrics.get_meter_provider().shutdown() - async def log_event(self, event: Event) -> None: + async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: if isinstance(event, UnstructuredLogEvent): - self._log_unstructured(event) + self._log_unstructured(event, ttl_seconds) elif isinstance(event, MetricEvent): self._log_metric(event) elif isinstance(event, StructuredLogEvent): - self._log_structured(event) + self._log_structured(event, ttl_seconds) + else: + raise ValueError(f"Unknown event type: {event}") - def _log_unstructured(self, event: UnstructuredLogEvent) -> None: + def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: # Use global storage instead of instance storage span_id = string_to_span_id(event.span_id) @@ -104,6 +124,7 @@ class OpenTelemetryAdapter(Telemetry): attributes={ "message": event.message, "severity": event.severity.value, + "__ttl__": ttl_seconds, **event.attributes, }, timestamp=timestamp_ns, @@ -154,11 +175,14 @@ class OpenTelemetryAdapter(Telemetry): ) return _GLOBAL_STORAGE["up_down_counters"][name] - def _log_structured(self, event: StructuredLogEvent) -> None: + def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: span_id = string_to_span_id(event.span_id) trace_id = string_to_trace_id(event.trace_id) tracer = trace.get_tracer(__name__) + if event.attributes is None: + event.attributes = {} + event.attributes["__ttl__"] = ttl_seconds if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates @@ -170,7 +194,6 @@ class OpenTelemetryAdapter(Telemetry): parent_span_id = string_to_span_id(event.payload.parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - # Create a new trace context with the trace_id context = trace.Context(trace_id=trace_id) if parent_span: context = trace.set_span_in_context(parent_span, context) @@ -179,14 +202,9 @@ class OpenTelemetryAdapter(Telemetry): name=event.payload.name, context=context, attributes=event.attributes or {}, - start_time=int(event.timestamp.timestamp() * 1e9), ) _GLOBAL_STORAGE["active_spans"][span_id] = span - # Set as current span using context manager - with trace.use_span(span, end_on_exit=False): - pass # Let the span continue beyond this block - elif isinstance(event.payload, SpanEndPayload): span = _GLOBAL_STORAGE["active_spans"].get(span_id) if span: @@ -199,10 +217,43 @@ class OpenTelemetryAdapter(Telemetry): else trace.Status(status_code=trace.StatusCode.ERROR) ) span.set_status(status) - span.end(end_time=int(event.timestamp.timestamp() * 1e9)) - - # Remove from active spans + span.end() _GLOBAL_STORAGE["active_spans"].pop(span_id, None) + else: + raise ValueError(f"Unknown structured log event: {event}") - async def get_trace(self, trace_id: str) -> Trace: - raise NotImplementedError("Trace retrieval not implemented yet") + async def query_traces( + self, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + return await self.trace_store.query_traces( + attribute_conditions=attribute_conditions, + attribute_keys_to_return=attribute_keys_to_return, + limit=limit, + offset=offset, + order_by=order_by, + ) + + async def get_spans( + self, + span_id: str, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> SpanWithChildren: + return await self.trace_store.get_spans( + span_id=span_id, + attribute_conditions=attribute_conditions, + attribute_keys_to_return=attribute_keys_to_return, + max_depth=max_depth, + limit=limit, + offset=offset, + order_by=order_by, + ) diff --git a/llama_stack/providers/utils/telemetry/sqlite.py b/llama_stack/providers/utils/telemetry/sqlite.py new file mode 100644 index 000000000..e7161fffa --- /dev/null +++ b/llama_stack/providers/utils/telemetry/sqlite.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import datetime +from typing import List, Optional + +import aiosqlite + +from llama_stack.apis.telemetry import ( + QueryCondition, + SpanWithChildren, + Trace, + TraceStore, +) + + +class SQLiteTraceStore(TraceStore): + def __init__(self, conn_string: str): + self.conn_string = conn_string + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + attributes_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + print(attribute_filters, attributes_to_return, limit, offset, order_by) + + def build_attribute_select() -> str: + if not attributes_to_return: + return "" + return "".join( + f", json_extract(s.attributes, '$.{key}') as attr_{key}" + for key in attributes_to_return + ) + + def build_where_clause() -> tuple[str, list]: + if not attribute_filters: + return "", [] + + conditions = [ + f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?" + for condition in attribute_filters + ] + params = [condition.value for condition in attribute_filters] + where_clause = " WHERE " + " AND ".join(conditions) + return where_clause, params + + def build_order_clause() -> str: + if not order_by: + return "" + + order_clauses = [] + for field in order_by: + desc = field.startswith("-") + clean_field = field[1:] if desc else field + order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}") + return " ORDER BY " + ", ".join(order_clauses) + + # Build the main query + base_query = """ + WITH matching_traces AS ( + SELECT DISTINCT t.trace_id + FROM traces t + JOIN spans s ON t.trace_id = s.trace_id + {where_clause} + ), + filtered_traces AS ( + SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time + {attribute_select} + FROM matching_traces mt + JOIN traces t ON mt.trace_id = t.trace_id + LEFT JOIN spans s ON t.trace_id = s.trace_id + {order_clause} + ) + SELECT DISTINCT trace_id, root_span_id, start_time, end_time + FROM filtered_traces + LIMIT {limit} OFFSET {offset} + """ + + where_clause, params = build_where_clause() + query = base_query.format( + attribute_select=build_attribute_select(), + where_clause=where_clause, + order_clause=build_order_clause(), + limit=limit, + offset=offset, + ) + + # Execute query and return results + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [ + Trace( + trace_id=row["trace_id"], + root_span_id=row["root_span_id"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + ) + for row in rows + ] + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + # Build the attributes selection + attributes_select = "s.attributes" + if attributes_to_return: + json_object = ", ".join( + f"'{key}', json_extract(s.attributes, '$.{key}')" + for key in attributes_to_return + ) + attributes_select = f"json_object({json_object})" + + # SQLite CTE query with filtered attributes + query = f""" + WITH RECURSIVE span_tree AS ( + SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes + FROM spans s + WHERE s.span_id = ? + + UNION ALL + + SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes + FROM spans s + JOIN span_tree st ON s.parent_span_id = st.span_id + WHERE (? IS NULL OR st.depth < ?) + ) + SELECT * + FROM span_tree + ORDER BY depth, start_time + """ + + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor: + rows = await cursor.fetchall() + + if not rows: + raise ValueError(f"Span {span_id} not found") + + # Build span tree + spans_by_id = {} + root_span = None + + for row in rows: + span = SpanWithChildren( + span_id=row["span_id"], + trace_id=row["trace_id"], + parent_span_id=row["parent_span_id"], + name=row["name"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + attributes=json.loads(row["filtered_attributes"]), + status=row["status"].lower(), + children=[], + ) + + spans_by_id[span.span_id] = span + + if span.span_id == span_id: + root_span = span + elif span.parent_span_id in spans_by_id: + spans_by_id[span.parent_span_id].children.append(span) + + return root_span diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py new file mode 100644 index 000000000..ed1343e0b --- /dev/null +++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import datetime +from typing import List, Optional, Protocol + +import aiosqlite + +from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace + + +class TraceStore(Protocol): + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: ... + + +class SQLiteTraceStore(TraceStore): + def __init__(self, conn_string: str): + self.conn_string = conn_string + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + + def build_where_clause() -> tuple[str, list]: + if not attribute_filters: + return "", [] + + ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"} + + conditions = [ + f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?" + for condition in attribute_filters + ] + params = [condition.value for condition in attribute_filters] + where_clause = " WHERE " + " AND ".join(conditions) + return where_clause, params + + def build_order_clause() -> str: + if not order_by: + return "" + + order_clauses = [] + for field in order_by: + desc = field.startswith("-") + clean_field = field[1:] if desc else field + order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}") + return " ORDER BY " + ", ".join(order_clauses) + + # Build the main query + base_query = """ + WITH matching_traces AS ( + SELECT DISTINCT t.trace_id + FROM traces t + JOIN spans s ON t.trace_id = s.trace_id + {where_clause} + ), + filtered_traces AS ( + SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time + FROM matching_traces mt + JOIN traces t ON mt.trace_id = t.trace_id + LEFT JOIN spans s ON t.trace_id = s.trace_id + {order_clause} + ) + SELECT DISTINCT trace_id, root_span_id, start_time, end_time + FROM filtered_traces + LIMIT {limit} OFFSET {offset} + """ + + where_clause, params = build_where_clause() + query = base_query.format( + where_clause=where_clause, + order_clause=build_order_clause(), + limit=limit, + offset=offset, + ) + + # Execute query and return results + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [ + Trace( + trace_id=row["trace_id"], + root_span_id=row["root_span_id"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + ) + for row in rows + ] + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + # Build the attributes selection + attributes_select = "s.attributes" + if attributes_to_return: + json_object = ", ".join( + f"'{key}', json_extract(s.attributes, '$.{key}')" + for key in attributes_to_return + ) + attributes_select = f"json_object({json_object})" + + # SQLite CTE query with filtered attributes + query = f""" + WITH RECURSIVE span_tree AS ( + SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes + FROM spans s + WHERE s.span_id = ? + + UNION ALL + + SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes + FROM spans s + JOIN span_tree st ON s.parent_span_id = st.span_id + WHERE (? IS NULL OR st.depth < ?) + ) + SELECT * + FROM span_tree + ORDER BY depth, start_time + """ + + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor: + rows = await cursor.fetchall() + + if not rows: + raise ValueError(f"Span {span_id} not found") + + # Build span tree + spans_by_id = {} + root_span = None + + for row in rows: + span = SpanWithChildren( + span_id=row["span_id"], + trace_id=row["trace_id"], + parent_span_id=row["parent_span_id"], + name=row["name"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + attributes=json.loads(row["filtered_attributes"]), + status=row["status"].lower(), + children=[], + ) + + spans_by_id[span.span_id] = span + + if span.span_id == span_id: + root_span = span + elif span.parent_span_id in spans_by_id: + spans_by_id[span.parent_span_id].children.append(span) + + return root_span diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index b53dc0df9..54558afdc 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -69,7 +69,7 @@ class TraceContext: self.logger = logger self.trace_id = trace_id - def push_span(self, name: str, attributes: Dict[str, Any] = None): + def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span: current_span = self.get_current_span() span = Span( span_id=generate_short_uuid(), @@ -94,6 +94,7 @@ class TraceContext: ) self.spans.append(span) + return span def pop_span(self, status: SpanStatus = SpanStatus.OK): span = self.spans.pop() @@ -203,12 +204,13 @@ class SpanContextManager: def __init__(self, name: str, attributes: Dict[str, Any] = None): self.name = name self.attributes = attributes + self.span = None def __enter__(self): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT if context: - context.push_span(self.name, self.attributes) + self.span = context.push_span(self.name, self.attributes) return self def __exit__(self, exc_type, exc_value, traceback): @@ -217,11 +219,24 @@ class SpanContextManager: if context: context.pop_span() + def set_attribute(self, key: str, value: Any): + if self.span: + if self.span.attributes is None: + self.span.attributes = {} + self.span.attributes[key] = value + async def __aenter__(self): - return self.__enter__() + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + self.span = context.push_span(self.name, self.attributes) + return self async def __aexit__(self, exc_type, exc_value, traceback): - self.__exit__(exc_type, exc_value, traceback) + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + context.pop_span() def __call__(self, func: Callable): @wraps(func) @@ -246,3 +261,11 @@ class SpanContextManager: def span(name: str, attributes: Dict[str, Any] = None): return SpanContextManager(name, attributes) + + +def get_current_span() -> Optional[Span]: + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + return context.get_current_span() + return None