diff --git a/llama_toolchain/agentic_system/providers.py b/llama_toolchain/agentic_system/providers.py
index 164df1a30..e2d1e6424 100644
--- a/llama_toolchain/agentic_system/providers.py
+++ b/llama_toolchain/agentic_system/providers.py
@@ -9,7 +9,7 @@ from typing import List
 from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
 
 
-def available_agentic_system_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.agentic_system,
diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py
index 2405a57ce..cd9fc9dcf 100644
--- a/llama_toolchain/core/datatypes.py
+++ b/llama_toolchain/core/datatypes.py
@@ -19,6 +19,7 @@ class Api(Enum):
     safety = "safety"
     agentic_system = "agentic_system"
     memory = "memory"
+    telemetry = "telemetry"
 
 
 @json_schema_type
diff --git a/llama_toolchain/core/distribution.py b/llama_toolchain/core/distribution.py
index 89e1d7793..1a22b7b06 100644
--- a/llama_toolchain/core/distribution.py
+++ b/llama_toolchain/core/distribution.py
@@ -4,17 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import importlib
 import inspect
 from typing import Dict, List
 
 from llama_toolchain.agentic_system.api import AgenticSystem
-from llama_toolchain.agentic_system.providers import available_agentic_system_providers
 from llama_toolchain.inference.api import Inference
-from llama_toolchain.inference.providers import available_inference_providers
 from llama_toolchain.memory.api import Memory
-from llama_toolchain.memory.providers import available_memory_providers
 from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.providers import available_safety_providers
+from llama_toolchain.telemetry.api import Telemetry
 
 from .datatypes import (
     Api,
@@ -44,7 +42,7 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
 
 
 def stack_apis() -> List[Api]:
-    return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
+    return [v for v in Api]
 
 
 def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
@@ -55,6 +53,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.safety: Safety,
         Api.agentic_system: AgenticSystem,
         Api.memory: Memory,
+        Api.telemetry: Telemetry,
     }
 
     for api, protocol in protocols.items():
@@ -82,20 +81,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 
 
 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
-    inference_providers_by_id = {
-        a.provider_type: a for a in available_inference_providers()
-    }
-    safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
-    agentic_system_providers_by_id = {
-        a.provider_type: a for a in available_agentic_system_providers()
-    }
+    ret = {}
+    for api in stack_apis():
+        name = api.name.lower()
+        module = importlib.import_module(f"llama_toolchain.{name}.providers")
+        ret[api] = {
+            "remote": remote_provider_spec(api),
+            **{a.provider_type: a for a in module.available_providers()},
+        }
 
-    ret = {
-        Api.inference: inference_providers_by_id,
-        Api.safety: safety_providers_by_id,
-        Api.agentic_system: agentic_system_providers_by_id,
-        Api.memory: {a.provider_type: a for a in available_memory_providers()},
-    }
-    for k, v in ret.items():
-        v["remote"] = remote_provider_spec(k)
     return ret
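Note on the refactor above: renaming every per-API `available_*_providers()` function to the uniform `available_providers()` is what lets `api_providers()` discover providers by convention instead of by hard-coded imports. A minimal self-contained sketch of that convention (the `discover` helper is hypothetical; the module layout is taken from the diff):

import importlib


def discover(package: str, api_name: str) -> dict:
    # e.g. discover("llama_toolchain", "telemetry") imports
    # llama_toolchain.telemetry.providers and indexes its specs
    module = importlib.import_module(f"{package}.{api_name}.providers")
    return {spec.provider_type: spec for spec in module.available_providers()}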
diff --git a/llama_toolchain/core/distribution_registry.py b/llama_toolchain/core/distribution_registry.py
index 2b15af72b..75c43ed34 100644
--- a/llama_toolchain/core/distribution_registry.py
+++ b/llama_toolchain/core/distribution_registry.py
@@ -21,12 +21,16 @@ def available_distribution_specs() -> List[DistributionSpec]:
             Api.memory: "meta-reference-faiss",
             Api.safety: "meta-reference",
             Api.agentic_system: "meta-reference",
+            Api.telemetry: "console",
         },
     ),
     DistributionSpec(
         distribution_type="remote",
         description="Point to remote services for all llama stack APIs",
-        providers={x: "remote" for x in Api},
+        providers={
+            **{x: "remote" for x in Api},
+            Api.telemetry: "console",
+        },
     ),
     DistributionSpec(
         distribution_type="local-ollama",
@@ -36,6 +40,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
             Api.safety: "meta-reference",
             Api.agentic_system: "meta-reference",
             Api.memory: "meta-reference-faiss",
+            Api.telemetry: "console",
         },
     ),
     DistributionSpec(
@@ -46,6 +51,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
             Api.safety: "meta-reference",
             Api.agentic_system: "meta-reference",
             Api.memory: "meta-reference-faiss",
+            Api.telemetry: "console",
         },
     ),
     DistributionSpec(
@@ -56,6 +62,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
             Api.safety: "meta-reference",
             Api.agentic_system: "meta-reference",
             Api.memory: "meta-reference-faiss",
+            Api.telemetry: "console",
         },
     ),
 ]
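One detail worth noting in the "remote" spec above: in a dict literal, later keys override earlier ones, so telemetry stays pinned to the in-process console provider even though every other API points at a remote service. A tiny standalone demonstration (the three-member enum is illustrative, not the real Api enum):

from enum import Enum


class Api(Enum):
    inference = "inference"
    safety = "safety"
    telemetry = "telemetry"


providers = {**{x: "remote" for x in Api}, Api.telemetry: "console"}
assert providers[Api.inference] == "remote"
assert providers[Api.telemetry] == "console"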
diff --git a/llama_toolchain/core/server.py b/llama_toolchain/core/server.py
index 8c7ab10a7..b0ec75fe5 100644
--- a/llama_toolchain/core/server.py
+++ b/llama_toolchain/core/server.py
@@ -38,6 +38,13 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated
 
+from llama_toolchain.telemetry.tracing import (
+    end_trace,
+    setup_logger,
+    SpanStatus,
+    start_trace,
+)
+
 from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from .distribution import api_endpoints, api_providers
 from .dynamic import instantiate_provider
@@ -88,6 +95,8 @@ async def passthrough(
     downstream_url: str,
     downstream_headers: Optional[Dict[str, str]] = None,
 ):
+    await start_trace(request.path, {"downstream_url": downstream_url})
+
     headers = dict(request.headers)
     headers.pop("host", None)
     headers.update(downstream_headers or {})
@@ -95,6 +104,7 @@ async def passthrough(
     content = await request.body()
 
     client = httpx.AsyncClient()
+    erred = False
     try:
         req = client.build_request(
             method=request.method,
@@ -120,17 +130,25 @@ async def passthrough(
         )
 
     except httpx.ReadTimeout:
+        erred = True
         return Response(content="Downstream server timed out", status_code=504)
     except httpx.NetworkError as e:
+        erred = True
         return Response(content=f"Network error: {str(e)}", status_code=502)
     except httpx.TooManyRedirects:
+        erred = True
         return Response(content="Too many redirects", status_code=502)
     except SSLError as e:
+        erred = True
         return Response(content=f"SSL error: {str(e)}", status_code=502)
     except httpx.HTTPStatusError as e:
+        erred = True
         return Response(content=str(e), status_code=e.response.status_code)
     except Exception as e:
+        erred = True
         return Response(content=f"Unexpected error: {str(e)}", status_code=500)
+    finally:
+        await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
 
 
 def handle_sigint(*args, **kwargs):
@@ -159,7 +177,7 @@ def create_dynamic_passthrough(
 
 def create_dynamic_typed_route(func: Any, method: str):
     hints = get_type_hints(func)
-    response_model = hints["return"]
+    response_model = hints.get("return")
 
     # NOTE: I think it is better to just add a method within each Api
     # "Protocol" / adapter-impl to tell what sort of a response this request
@@ -170,6 +188,8 @@ def create_dynamic_typed_route(func: Any, method: str):
     if is_streaming:
 
         async def endpoint(**kwargs):
+            await start_trace(func.__name__)
+
             async def sse_generator(event_gen):
                 try:
                     async for item in event_gen:
@@ -187,6 +207,8 @@ def create_dynamic_typed_route(func: Any, method: str):
                             },
                         }
                     )
+                finally:
+                    await end_trace()
 
             return StreamingResponse(
                 sse_generator(func(**kwargs)), media_type="text/event-stream"
@@ -195,6 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str):
     else:
 
         async def endpoint(**kwargs):
+            await start_trace(func.__name__)
             try:
                 return (
                     await func(**kwargs)
@@ -204,6 +227,8 @@ def create_dynamic_typed_route(func: Any, method: str):
             except Exception as e:
                 traceback.print_exception(e)
                 raise translate_exception(e) from e
+            finally:
+                await end_trace()
 
     sig = inspect.signature(func)
     if method == "post":
@@ -293,6 +318,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
             provider_specs[api] = providers[provider_type]
 
     impls = resolve_impls(provider_specs, config)
+    if Api.telemetry in impls:
+        setup_logger(impls[Api.telemetry])
 
     for provider_spec in provider_specs.values():
         api = provider_spec.api
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index 5219585c3..6e80ecb37 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -9,7 +9,7 @@ from typing import List
 from llama_toolchain.core.datatypes import *  # noqa: F403
 
 
-def available_inference_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
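The server changes above always pair `start_trace` with `end_trace` in a `finally`, so a trace is closed even when the handler raises. Boiled down, the non-streaming wrapper follows this shape (the `traced` helper is a sketch, not code from the diff):

from llama_toolchain.telemetry.tracing import end_trace, start_trace


async def traced(func, **kwargs):
    await start_trace(func.__name__)
    try:
        return await func(**kwargs)
    finally:
        await end_trace()

For streaming routes the same pairing holds, except `end_trace` runs in the SSE generator's `finally`, after the last event has been yielded.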
diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py
index 807aa208f..2dcff4d25 100644
--- a/llama_toolchain/memory/meta_reference/faiss/faiss.py
+++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import logging
 import uuid
 
 from typing import Any, Dict, List, Optional
@@ -20,8 +21,11 @@ from llama_toolchain.memory.common.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
 )
+from llama_toolchain.telemetry import tracing
 
 from .config import FaissImplConfig
 
+logger = logging.getLogger(__name__)
+
 
 class FaissIndex(EmbeddingIndex):
     id_by_index: Dict[int, str]
@@ -32,11 +36,12 @@ class FaissIndex(EmbeddingIndex):
         self.id_by_index = {}
         self.chunk_by_index = {}
 
+    @tracing.span(name="add_chunks")
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
-            print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
+            logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
             self.id_by_index[indexlen + i] = chunk.document_id
 
         self.index.add(np.array(embeddings).astype(np.float32))
diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py
index cc113d132..525f947a0 100644
--- a/llama_toolchain/memory/providers.py
+++ b/llama_toolchain/memory/providers.py
@@ -14,7 +14,7 @@ EMBEDDING_DEPS = [
 ]
 
 
-def available_memory_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.memory,
diff --git a/llama_toolchain/safety/providers.py b/llama_toolchain/safety/providers.py
index 8471ab139..0db454ef3 100644
--- a/llama_toolchain/safety/providers.py
+++ b/llama_toolchain/safety/providers.py
@@ -9,7 +9,7 @@ from typing import List
 from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
 
 
-def available_safety_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.safety,
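The faiss.py change above shows the intended call-site style for the new decorator: wrap a method in `@tracing.span(...)` and any logging it does is attributed to that span. A minimal usage sketch (the `Indexer` class is hypothetical):

from llama_toolchain.telemetry import tracing


class Indexer:
    @tracing.span(name="add_chunks")
    async def add_chunks(self, chunks):
        # runs as a child span of the current trace; if no trace is
        # active, the decorator is a no-op and the method runs normally
        ...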
diff --git a/llama_toolchain/telemetry/api/api.py b/llama_toolchain/telemetry/api/api.py
index ae784428b..100836b46 100644
--- a/llama_toolchain/telemetry/api/api.py
+++ b/llama_toolchain/telemetry/api/api.py
@@ -6,170 +6,126 @@
 
 from datetime import datetime
 from enum import Enum
-
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, Literal, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
-
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
 
 
 @json_schema_type
-class ExperimentStatus(Enum):
-    NOT_STARTED = "not_started"
-    RUNNING = "running"
-    COMPLETED = "completed"
-    FAILED = "failed"
+class SpanStatus(Enum):
+    OK = "ok"
+    ERROR = "error"
 
 
 @json_schema_type
-class Experiment(BaseModel):
-    id: str
+class Span(BaseModel):
+    span_id: str
+    trace_id: str
+    parent_span_id: Optional[str] = None
     name: str
-    status: ExperimentStatus
-    created_at: datetime
-    updated_at: datetime
-    metadata: Dict[str, Any]
+    start_time: datetime
+    end_time: Optional[datetime] = None
+    attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
 
 
 @json_schema_type
-class Run(BaseModel):
-    id: str
-    experiment_id: str
-    status: str
-    started_at: datetime
-    ended_at: Optional[datetime]
-    metadata: Dict[str, Any]
+class Trace(BaseModel):
+    trace_id: str
+    root_span_id: str
+    start_time: datetime
+    end_time: Optional[datetime] = None
 
 
 @json_schema_type
-class Metric(BaseModel):
-    name: str
-    value: Union[float, int, str, bool]
-    timestamp: datetime
-    run_id: str
-
-
-@json_schema_type
-class Log(BaseModel):
-    message: str
-    level: str
-    timestamp: datetime
-    additional_info: Dict[str, Any]
-
-
-@json_schema_type
-class ArtifactType(Enum):
-    MODEL = "model"
-    DATASET = "dataset"
-    CHECKPOINT = "checkpoint"
-    PLOT = "plot"
+class EventType(Enum):
+    UNSTRUCTURED_LOG = "unstructured_log"
+    STRUCTURED_LOG = "structured_log"
     METRIC = "metric"
-    CONFIG = "config"
-    CODE = "code"
-    OTHER = "other"
 
 
 @json_schema_type
-class Artifact(BaseModel):
-    id: str
+class LogSeverity(Enum):
+    VERBOSE = "verbose"
+    DEBUG = "debug"
+    INFO = "info"
+    WARN = "warn"
+    ERROR = "error"
+    CRITICAL = "critical"
+
+
+class EventCommon(BaseModel):
+    trace_id: str
+    span_id: str
+    timestamp: datetime
+    attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
+
+
+@json_schema_type
+class UnstructuredLogEvent(EventCommon):
+    type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
+    message: str
+    severity: LogSeverity
+
+
+@json_schema_type
+class MetricEvent(EventCommon):
+    type: Literal[EventType.METRIC.value] = EventType.METRIC.value
+    metric: str  # this would be an enum
+    value: Union[int, float]
+    unit: str
+
+
+@json_schema_type
+class StructuredLogType(Enum):
+    SPAN_START = "span_start"
+    SPAN_END = "span_end"
+
+
+@json_schema_type
+class SpanStartPayload(BaseModel):
+    type: Literal[StructuredLogType.SPAN_START.value] = (
+        StructuredLogType.SPAN_START.value
+    )
     name: str
-    type: ArtifactType
-    size: int
-    created_at: datetime
-    metadata: Dict[str, Any]
+    parent_span_id: Optional[str] = None
 
 
 @json_schema_type
-class CreateExperimentRequest(BaseModel):
-    name: str
-    metadata: Optional[Dict[str, Any]] = None
+class SpanEndPayload(BaseModel):
+    type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
+    status: SpanStatus
+
+
+StructuredLogPayload = Annotated[
+    Union[
+        SpanStartPayload,
+        SpanEndPayload,
+    ],
+    Field(discriminator="type"),
+]
 
 
 @json_schema_type
-class UpdateExperimentRequest(BaseModel):
-    experiment_id: str
-    status: Optional[ExperimentStatus] = None
-    metadata: Optional[Dict[str, Any]] = None
+class StructuredLogEvent(EventCommon):
+    type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
+    payload: StructuredLogPayload
 
 
-@json_schema_type
-class CreateRunRequest(BaseModel):
-    experiment_id: str
-    metadata: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-class UpdateRunRequest(BaseModel):
-    run_id: str
-    status: Optional[str] = None
-    ended_at: Optional[datetime] = None
-    metadata: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-class LogMetricsRequest(BaseModel):
-    run_id: str
-    metrics: List[Metric]
-
-
-@json_schema_type
-class LogMessagesRequest(BaseModel):
-    logs: List[Log]
-    run_id: Optional[str] = None
-
-
-@json_schema_type
-class UploadArtifactRequest(BaseModel):
-    experiment_id: str
-    name: str
-    artifact_type: str
-    content: bytes
-    metadata: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-class LogSearchRequest(BaseModel):
-    query: str
-    filters: Optional[Dict[str, Any]] = None
+Event = Annotated[
+    Union[
+        UnstructuredLogEvent,
+        MetricEvent,
+        StructuredLogEvent,
+    ],
+    Field(discriminator="type"),
+]
 
 
 class Telemetry(Protocol):
-    @webmethod(route="/experiments/create")
-    def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
+    @webmethod(route="/telemetry/log_event")
+    async def log_event(self, event: Event): ...
 
-    @webmethod(route="/experiments/list")
-    def list_experiments(self) -> List[Experiment]: ...
-
-    @webmethod(route="/experiments/get")
-    def get_experiment(self, experiment_id: str) -> Experiment: ...
-
-    @webmethod(route="/experiments/update")
-    def update_experiment(self, request: UpdateExperimentRequest) -> Experiment: ...
-
-    @webmethod(route="/experiments/create_run")
-    def create_run(self, request: CreateRunRequest) -> Run: ...
-
-    @webmethod(route="/runs/update")
-    def update_run(self, request: UpdateRunRequest) -> Run: ...
-
-    @webmethod(route="/runs/log_metrics")
-    def log_metrics(self, request: LogMetricsRequest) -> None: ...
-
-    @webmethod(route="/runs/metrics", method="GET")
-    def get_metrics(self, run_id: str) -> List[Metric]: ...
-
-    @webmethod(route="/logging/log_messages")
-    def log_messages(self, request: LogMessagesRequest) -> None: ...
-
-    @webmethod(route="/logging/get_logs")
-    def get_logs(self, request: LogSearchRequest) -> List[Log]: ...
-
-    @webmethod(route="/experiments/artifacts/upload")
-    def upload_artifact(self, request: UploadArtifactRequest) -> Artifact: ...
-
-    @webmethod(route="/experiments/artifacts/get")
-    def list_artifacts(self, experiment_id: str) -> List[Artifact]: ...
-
-    @webmethod(route="/artifacts/get")
-    def get_artifact(self, artifact_id: str) -> Artifact: ...
+    @webmethod(route="/telemetry/get_trace", method="GET")
+    async def get_trace(self, trace_id: str) -> Trace: ...
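Since each event model carries a `Literal` `type` field and `Event` is a `Field(discriminator="type")` union, a serialized event round-trips back to the right concrete class. A small sketch of constructing one (the IDs are illustrative):

from datetime import datetime

from llama_toolchain.telemetry.api import LogSeverity, UnstructuredLogEvent

event = UnstructuredLogEvent(
    trace_id="t-123",
    span_id="s-456",
    timestamp=datetime.now(),
    message="hello",
    severity=LogSeverity.INFO,
)
# event.type == "unstructured_log", which is what the discriminator keys on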
diff --git a/llama_toolchain/telemetry/console/__init__.py b/llama_toolchain/telemetry/console/__init__.py
new file mode 100644
index 000000000..4a0c2f6ee
--- /dev/null
+++ b/llama_toolchain/telemetry/console/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import ConsoleConfig
+
+
+async def get_provider_impl(config: ConsoleConfig, _deps):
+    from .console import ConsoleTelemetryImpl
+
+    impl = ConsoleTelemetryImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/telemetry/console/config.py b/llama_toolchain/telemetry/console/config.py
new file mode 100644
index 000000000..c639c6798
--- /dev/null
+++ b/llama_toolchain/telemetry/console/config.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+
+from pydantic import BaseModel
+
+
+@json_schema_type
+class ConsoleConfig(BaseModel): ...
diff --git a/llama_toolchain/telemetry/console/console.py b/llama_toolchain/telemetry/console/console.py
new file mode 100644
index 000000000..2e7b9980d
--- /dev/null
+++ b/llama_toolchain/telemetry/console/console.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_toolchain.telemetry.api import *  # noqa: F403
+
+from .config import ConsoleConfig
+
+
+class ConsoleTelemetryImpl(Telemetry):
+    def __init__(self, config: ConsoleConfig) -> None:
+        self.config = config
+        self.spans = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def log_event(self, event: Event):
+        if (
+            isinstance(event, StructuredLogEvent)
+            and event.payload.type == StructuredLogType.SPAN_START.value
+        ):
+            self.spans[event.span_id] = event.payload
+
+        # walk parent_span_id links to build a dotted span path for display
+        names = []
+        span_id = event.span_id
+        while True:
+            span_payload = self.spans.get(span_id)
+            if not span_payload:
+                break
+
+            names = [span_payload.name] + names
+            span_id = span_payload.parent_span_id
+
+        span_name = ".".join(names) if names else None
+
+        formatted = format_event(event, span_name)
+        if formatted:
+            print(formatted)
+
+    async def get_trace(self, trace_id: str) -> Trace:
+        raise NotImplementedError()
+
+
+COLORS = {
+    "reset": "\033[0m",
+    "bold": "\033[1m",
+    "dim": "\033[2m",
+    "red": "\033[31m",
+    "green": "\033[32m",
+    "yellow": "\033[33m",
+    "blue": "\033[34m",
+    "magenta": "\033[35m",
+    "cyan": "\033[36m",
+    "white": "\033[37m",
+}
+
+SEVERITY_COLORS = {
+    LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
+    LogSeverity.DEBUG: COLORS["cyan"],
+    LogSeverity.INFO: COLORS["green"],
+    LogSeverity.WARN: COLORS["yellow"],
+    LogSeverity.ERROR: COLORS["red"],
+    LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
+}
+
+
+def format_event(event: Event, span_name: Optional[str]) -> Optional[str]:
+    timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
+
+    span = ""
+    if span_name:
+        span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
+
+    if isinstance(event, UnstructuredLogEvent):
+        severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
+        return (
+            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
+            f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
+            f"{span}"
+            f"{event.message}"
+        )
+
+    elif isinstance(event, StructuredLogEvent):
+        return None
+
+    return f"Unknown event type: {event}"
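For reference, the console backend can also be driven directly, outside the server, through the factory in `__init__.py`; a rough sketch, assuming no dependencies are needed so `_deps` is passed as an empty dict:

from llama_toolchain.telemetry.console import ConsoleConfig, get_provider_impl


async def demo(event):
    impl = await get_provider_impl(ConsoleConfig(), {})
    await impl.log_event(event)  # prints a color-coded line for unstructured logs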
diff --git a/llama_toolchain/telemetry/providers.py b/llama_toolchain/telemetry/providers.py
new file mode 100644
index 000000000..7b04145b3
--- /dev/null
+++ b/llama_toolchain/telemetry/providers.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.core.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.telemetry,
+            provider_type="console",
+            pip_packages=[],
+            module="llama_toolchain.telemetry.console",
+            config_class="llama_toolchain.telemetry.console.ConsoleConfig",
+        ),
+    ]
diff --git a/llama_toolchain/telemetry/tracing.py b/llama_toolchain/telemetry/tracing.py
new file mode 100644
index 000000000..6afe5c2fb
--- /dev/null
+++ b/llama_toolchain/telemetry/tracing.py
@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import base64
+import logging
+import queue
+import threading
+import uuid
+from datetime import datetime
+from functools import wraps
+from typing import Any, Dict, List
+
+
+from llama_toolchain.telemetry.api import *  # noqa: F403
+
+
+def generate_short_uuid(length: int = 12):
+    full_uuid = uuid.uuid4()
+    uuid_bytes = full_uuid.bytes
+    encoded = base64.urlsafe_b64encode(uuid_bytes)
+    return encoded.rstrip(b"=").decode("ascii")[:length]
+
+
+# NOTE: a single module-level context; concurrent requests in one process
+# will overwrite each other's trace
+CURRENT_TRACE_CONTEXT = None
+BACKGROUND_LOGGER = None
+
+
+class BackgroundLogger:
+    def __init__(self, api: Telemetry, capacity: int = 1000):
+        self.api = api
+        self.log_queue = queue.Queue(maxsize=capacity)
+        self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
+        self.worker_thread.start()
+
+    def log_event(self, event):
+        try:
+            self.log_queue.put_nowait(event)
+        except queue.Full:
+            print("Log queue is full, dropping event")
+
+    def _process_logs(self):
+        while True:
+            try:
+                event = self.log_queue.get()
+                # figure out how to use a thread's native loop
+                asyncio.run(self.api.log_event(event))
+            except Exception:
+                import traceback
+
+                traceback.print_exc()
+                print("Error processing log event")
+            finally:
+                self.log_queue.task_done()
+
+    def __del__(self):
+        self.log_queue.join()
+
+
+class TraceContext:
+    def __init__(self, logger: BackgroundLogger, trace_id: str):
+        self.logger = logger
+        self.trace_id = trace_id
+        # per-instance span stack; a class-level default would be shared
+        # across every trace
+        self.spans: List[Span] = []
+
+    def push_span(self, name: str, attributes: Dict[str, Any] = None):
+        current_span = self.get_current_span()
+        span = Span(
+            span_id=generate_short_uuid(),
+            trace_id=self.trace_id,
+            name=name,
+            start_time=datetime.now(),
+            parent_span_id=current_span.span_id if current_span else None,
+            attributes=attributes,
+        )
+
+        self.logger.log_event(
+            StructuredLogEvent(
+                trace_id=span.trace_id,
+                span_id=span.span_id,
+                timestamp=span.start_time,
+                attributes=span.attributes,
+                payload=SpanStartPayload(
+                    name=span.name,
+                    parent_span_id=span.parent_span_id,
+                ),
+            )
+        )
+
+        self.spans.append(span)
+
+    def pop_span(self, status: SpanStatus = SpanStatus.OK):
+        if not self.spans:
+            return
+
+        span = self.spans.pop()
+        self.logger.log_event(
+            StructuredLogEvent(
+                trace_id=span.trace_id,
+                span_id=span.span_id,
+                timestamp=span.start_time,
+                attributes=span.attributes,
+                payload=SpanEndPayload(
+                    status=status,
+                ),
+            )
+        )
+
+    def get_current_span(self):
+        return self.spans[-1] if self.spans else None
+
+
+def setup_logger(api: Telemetry, level: int = logging.INFO):
+    global BACKGROUND_LOGGER
+
+    BACKGROUND_LOGGER = BackgroundLogger(api)
+    logger = logging.getLogger()
+    logger.setLevel(level)
+    logger.addHandler(TelemetryHandler())
+
+
+async def start_trace(name: str, attributes: Dict[str, Any] = None):
+    global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
+
+    if BACKGROUND_LOGGER is None:
+        print("No Telemetry implementation set. Skipping trace initialization...")
+        return
+
+    trace_id = generate_short_uuid()
+    context = TraceContext(BACKGROUND_LOGGER, trace_id)
+    context.push_span(name, {"__root__": True, **(attributes or {})})
+
+    CURRENT_TRACE_CONTEXT = context
+
+
+async def end_trace(status: SpanStatus = SpanStatus.OK):
+    global CURRENT_TRACE_CONTEXT
+
+    context = CURRENT_TRACE_CONTEXT
+    if context is None:
+        return
+
+    context.pop_span(status)
+    CURRENT_TRACE_CONTEXT = None
+
+
+def severity(levelname: str) -> LogSeverity:
+    if levelname == "DEBUG":
+        return LogSeverity.DEBUG
+    elif levelname == "INFO":
+        return LogSeverity.INFO
+    elif levelname == "WARNING":
+        return LogSeverity.WARN  # stdlib levelname is WARNING; the enum member is WARN
+    elif levelname == "ERROR":
+        return LogSeverity.ERROR
+    elif levelname == "CRITICAL":
+        return LogSeverity.CRITICAL
+    else:
+        raise ValueError(f"Unknown log level: {levelname}")
+
+
+# TODO: ideally, the actual emitting should be done inside a separate daemon
+# process completely isolated from the server
+class TelemetryHandler(logging.Handler):
+    def emit(self, record: logging.LogRecord):
+        # horrendous hack to avoid logging from asyncio and getting into an infinite loop
+        if record.module in ("asyncio", "selector_events"):
+            return
+
+        global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
+
+        if BACKGROUND_LOGGER is None:
+            raise RuntimeError("Telemetry API not initialized")
+
+        context = CURRENT_TRACE_CONTEXT
+        if context is None:
+            return
+
+        span = context.get_current_span()
+        if span is None:
+            return
+
+        BACKGROUND_LOGGER.log_event(
+            UnstructuredLogEvent(
+                trace_id=span.trace_id,
+                span_id=span.span_id,
+                timestamp=datetime.now(),
+                message=self.format(record),
+                severity=severity(record.levelname),
+            )
+        )
+
+    def close(self):
+        pass
+
+
+def span(name: str, attributes: Dict[str, Any] = None):
+    def decorator(func):
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            try:
+                global CURRENT_TRACE_CONTEXT
+
+                context = CURRENT_TRACE_CONTEXT
+                if context:
+                    context.push_span(name, attributes)
+                result = func(*args, **kwargs)
+            finally:
+                # only pop if a span could have been pushed
+                if context:
+                    context.pop_span()
+            return result
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            try:
+                global CURRENT_TRACE_CONTEXT
+
+                context = CURRENT_TRACE_CONTEXT
+                if context:
+                    context.push_span(name, attributes)
+                result = await func(*args, **kwargs)
+            finally:
+                if context:
+                    context.pop_span()
+            return result
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if asyncio.iscoroutinefunction(func):
+                return async_wrapper(*args, **kwargs)
+            else:
+                return sync_wrapper(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
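Putting the pieces together: `setup_logger` installs the handler on the root logger, and anything logged between `start_trace` and `end_trace` is shipped to the active Telemetry impl as an `UnstructuredLogEvent`. An end-to-end sketch, assuming `telemetry_impl` is any Telemetry implementation such as the console provider:

import logging

from llama_toolchain.telemetry.tracing import end_trace, setup_logger, start_trace


async def main(telemetry_impl):
    setup_logger(telemetry_impl)  # installs TelemetryHandler on the root logger
    await start_trace("demo")  # opens the root span
    logging.getLogger(__name__).info("inside the trace")  # shipped as an UnstructuredLogEvent
    await end_trace()  # closes the root span with SpanStatus.OK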