Simplified the Telemetry API and tied it to the logger (#57)

* Simplified the Telemetry API and tied it to the logger

* Small update which adds a METRIC type

* Move span events one level down into structured log events

---------

Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Ashwin Bharambe 2024-09-11 14:25:37 -07:00 committed by GitHub
parent 1433aaf9f7
commit 191cd28831
15 changed files with 524 additions and 162 deletions
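
For orientation before the per-file diffs, here is a minimal sketch of how the pieces added in this commit fit together: a telemetry provider is built, setup_logger() ties the standard Python logger to it, and anything logged inside an active trace is forwarded as an unstructured log event. The config arguments, trace name, and attributes below are illustrative, not part of the commit.

import asyncio
import logging

from llama_toolchain.telemetry.console import ConsoleConfig, get_provider_impl
from llama_toolchain.telemetry.tracing import end_trace, setup_logger, start_trace


async def demo() -> None:
    # Build the console telemetry provider added in this commit and route the
    # root Python logger through it.
    telemetry = await get_provider_impl(ConsoleConfig(), {})
    setup_logger(telemetry)

    # Log records emitted between start_trace() and end_trace() are attached
    # to the current span as UnstructuredLogEvents. ("demo-request" and the
    # attributes are illustrative values.)
    await start_trace("demo-request", {"user": "example"})
    logging.getLogger(__name__).info("doing some work")
    await end_trace()


asyncio.run(demo())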

View file

@@ -9,7 +9,7 @@ from typing import List
from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_agentic_system_providers() -> List[ProviderSpec]:
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.agentic_system,

View file

@@ -19,6 +19,7 @@ class Api(Enum):
safety = "safety"
agentic_system = "agentic_system"
memory = "memory"
telemetry = "telemetry"
@json_schema_type

View file

@@ -4,17 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.memory.api import Memory
from llama_toolchain.memory.providers import available_memory_providers
from llama_toolchain.safety.api import Safety
from llama_toolchain.safety.providers import available_safety_providers
from llama_toolchain.telemetry.api import Telemetry
from .datatypes import (
Api,
@@ -44,7 +42,7 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
def stack_apis() -> List[Api]:
return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
return [v for v in Api]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
@@ -55,6 +53,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
Api.memory: Memory,
Api.telemetry: Telemetry,
}
for api, protocol in protocols.items():
@@ -82,20 +81,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
inference_providers_by_id = {
a.provider_type: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_type: a for a in available_agentic_system_providers()
}
ret = {}
for api in stack_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_toolchain.{name}.providers")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_type: a for a in module.available_providers()},
}
ret = {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_type: a for a in available_memory_providers()},
}
for k, v in ret.items():
v["remote"] = remote_provider_spec(k)
return ret

View file

@@ -21,12 +21,16 @@ def available_distribution_specs() -> List[DistributionSpec]:
Api.memory: "meta-reference-faiss",
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.telemetry: "console",
},
),
DistributionSpec(
distribution_type="remote",
description="Point to remote services for all llama stack APIs",
providers={x: "remote" for x in Api},
providers={
**{x: "remote" for x in Api},
Api.telemetry: "console",
},
),
DistributionSpec(
distribution_type="local-ollama",
@@ -36,6 +40,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.telemetry: "console",
},
),
DistributionSpec(
@@ -46,6 +51,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.telemetry: "console",
},
),
DistributionSpec(
@@ -56,6 +62,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.telemetry: "console",
},
),
]

View file

@@ -38,6 +38,13 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_toolchain.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,
start_trace,
)
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
@@ -88,6 +95,8 @@ async def passthrough(
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
await start_trace(request.path, {"downstream_url": downstream_url})
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
@@ -95,6 +104,7 @@ async def passthrough(
content = await request.body()
client = httpx.AsyncClient()
erred = False
try:
req = client.build_request(
method=request.method,
@@ -120,17 +130,25 @@ async def passthrough(
)
except httpx.ReadTimeout:
erred = True
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
erred = True
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
erred = True
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
erred = True
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
erred = True
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
erred = True
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
finally:
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(*args, **kwargs):
@@ -159,7 +177,7 @@ def create_dynamic_passthrough(
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints["return"]
response_model = hints.get("return")
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
@@ -170,6 +188,8 @@ def create_dynamic_typed_route(func: Any, method: str):
if is_streaming:
async def endpoint(**kwargs):
await start_trace(func.__name__)
async def sse_generator(event_gen):
try:
async for item in event_gen:
@@ -187,6 +207,8 @@ def create_dynamic_typed_route(func: Any, method: str):
},
}
)
finally:
await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
@@ -195,6 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str):
else:
async def endpoint(**kwargs):
await start_trace(func.__name__)
try:
return (
await func(**kwargs)
@@ -204,6 +227,8 @@ def create_dynamic_typed_route(func: Any, method: str):
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func)
if method == "post":
@@ -293,6 +318,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
provider_specs[api] = providers[provider_type]
impls = resolve_impls(provider_specs, config)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
for provider_spec in provider_specs.values():
api = provider_spec.api
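
The passthrough and endpoint changes above reduce to the same wrap-the-call pattern: open a trace, run the call, and always close the trace, marking it as an error if anything went wrong. A condensed sketch, where proxy, request, and forward are hypothetical stand-ins for the FastAPI handler and the httpx call:

from llama_toolchain.telemetry.tracing import SpanStatus, end_trace, start_trace


async def proxy(request, downstream_url: str):
    # Tag the trace with the downstream target, as the passthrough handler does.
    await start_trace(request.path, {"downstream_url": downstream_url})
    erred = False
    try:
        return await forward(request, downstream_url)  # hypothetical helper
    except Exception:
        erred = True
        raise
    finally:
        await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)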

View file

@@ -9,7 +9,7 @@ from typing import List
from llama_toolchain.core.datatypes import * # noqa: F403
def available_inference_providers() -> List[ProviderSpec]:
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.inference,

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
@@ -20,8 +21,11 @@ from llama_toolchain.memory.common.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from llama_toolchain.telemetry import tracing
from .config import FaissImplConfig
logger = logging.getLogger(__name__)
class FaissIndex(EmbeddingIndex):
id_by_index: Dict[int, str]
@@ -32,11 +36,12 @@ class FaissIndex(EmbeddingIndex):
self.id_by_index = {}
self.chunk_by_index = {}
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32))
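
The @tracing.span decorator used on add_chunks works on both sync and async callables and pushes a child span around the call whenever a trace context is active. A small illustrative use on a hypothetical method:

from llama_toolchain.telemetry import tracing


class MyIndex:
    @tracing.span(name="query")
    async def query(self, embedding) -> list:
        # Executes inside a child span named "query" when a trace is active.
        return []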

View file

@@ -14,7 +14,7 @@ EMBEDDING_DEPS = [
]
def available_memory_providers() -> List[ProviderSpec]:
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,

View file

@@ -9,7 +9,7 @@ from typing import List
from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_safety_providers() -> List[ProviderSpec]:
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.safety,

View file

@@ -6,170 +6,126 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import Any, Dict, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@json_schema_type
class ExperimentStatus(Enum):
NOT_STARTED = "not_started"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class SpanStatus(Enum):
OK = "ok"
ERROR = "error"
@json_schema_type
class Experiment(BaseModel):
id: str
class Span(BaseModel):
span_id: str
trace_id: str
parent_span_id: Optional[str] = None
name: str
status: ExperimentStatus
created_at: datetime
updated_at: datetime
metadata: Dict[str, Any]
start_time: datetime
end_time: Optional[datetime] = None
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class Run(BaseModel):
id: str
experiment_id: str
status: str
started_at: datetime
ended_at: Optional[datetime]
metadata: Dict[str, Any]
class Trace(BaseModel):
trace_id: str
root_span_id: str
start_time: datetime
end_time: Optional[datetime] = None
@json_schema_type
class Metric(BaseModel):
name: str
value: Union[float, int, str, bool]
timestamp: datetime
run_id: str
@json_schema_type
class Log(BaseModel):
message: str
level: str
timestamp: datetime
additional_info: Dict[str, Any]
@json_schema_type
class ArtifactType(Enum):
MODEL = "model"
DATASET = "dataset"
CHECKPOINT = "checkpoint"
PLOT = "plot"
class EventType(Enum):
UNSTRUCTURED_LOG = "unstructured_log"
STRUCTURED_LOG = "structured_log"
METRIC = "metric"
CONFIG = "config"
CODE = "code"
OTHER = "other"
@json_schema_type
class Artifact(BaseModel):
id: str
class LogSeverity(Enum):
VERBOSE = "verbose"
DEBUG = "debug"
INFO = "info"
WARN = "warn"
ERROR = "error"
CRITICAL = "critical"
class EventCommon(BaseModel):
trace_id: str
span_id: str
timestamp: datetime
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class UnstructuredLogEvent(EventCommon):
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
message: str
severity: LogSeverity
@json_schema_type
class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
metric: str # this would be an enum
value: Union[int, float]
unit: str
@json_schema_type
class StructuredLogType(Enum):
SPAN_START = "span_start"
SPAN_END = "span_end"
@json_schema_type
class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = (
StructuredLogType.SPAN_START.value
)
name: str
type: ArtifactType
size: int
created_at: datetime
metadata: Dict[str, Any]
parent_span_id: Optional[str] = None
@json_schema_type
class CreateExperimentRequest(BaseModel):
name: str
metadata: Optional[Dict[str, Any]] = None
class SpanEndPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
status: SpanStatus
StructuredLogPayload = Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
Field(discriminator="type"),
]
@json_schema_type
class UpdateExperimentRequest(BaseModel):
experiment_id: str
status: Optional[ExperimentStatus] = None
metadata: Optional[Dict[str, Any]] = None
class StructuredLogEvent(EventCommon):
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
payload: StructuredLogPayload
@json_schema_type
class CreateRunRequest(BaseModel):
experiment_id: str
metadata: Optional[Dict[str, Any]] = None
@json_schema_type
class UpdateRunRequest(BaseModel):
run_id: str
status: Optional[str] = None
ended_at: Optional[datetime] = None
metadata: Optional[Dict[str, Any]] = None
@json_schema_type
class LogMetricsRequest(BaseModel):
run_id: str
metrics: List[Metric]
@json_schema_type
class LogMessagesRequest(BaseModel):
logs: List[Log]
run_id: Optional[str] = None
@json_schema_type
class UploadArtifactRequest(BaseModel):
experiment_id: str
name: str
artifact_type: str
content: bytes
metadata: Optional[Dict[str, Any]] = None
@json_schema_type
class LogSearchRequest(BaseModel):
query: str
filters: Optional[Dict[str, Any]] = None
Event = Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
Field(discriminator="type"),
]
class Telemetry(Protocol):
@webmethod(route="/experiments/create")
def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
@webmethod(route="/telemetry/log_event")
async def log_event(self, event: Event): ...
@webmethod(route="/experiments/list")
def list_experiments(self) -> List[Experiment]: ...
@webmethod(route="/experiments/get")
def get_experiment(self, experiment_id: str) -> Experiment: ...
@webmethod(route="/experiments/update")
def update_experiment(self, request: UpdateExperimentRequest) -> Experiment: ...
@webmethod(route="/experiments/create_run")
def create_run(self, request: CreateRunRequest) -> Run: ...
@webmethod(route="/runs/update")
def update_run(self, request: UpdateRunRequest) -> Run: ...
@webmethod(route="/runs/log_metrics")
def log_metrics(self, request: LogMetricsRequest) -> None: ...
@webmethod(route="/runs/metrics", method="GET")
def get_metrics(self, run_id: str) -> List[Metric]: ...
@webmethod(route="/logging/log_messages")
def log_messages(self, request: LogMessagesRequest) -> None: ...
@webmethod(route="/logging/get_logs")
def get_logs(self, request: LogSearchRequest) -> List[Log]: ...
@webmethod(route="/experiments/artifacts/upload")
def upload_artifact(self, request: UploadArtifactRequest) -> Artifact: ...
@webmethod(route="/experiments/artifacts/get")
def list_artifacts(self, experiment_id: str) -> List[Artifact]: ...
@webmethod(route="/artifacts/get")
def get_artifact(self, artifact_id: str) -> Artifact: ...
@webmethod(route="/telemetry/get_trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
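
To make the new surface concrete, here is a sketch of constructing each event type by hand and feeding it to the single log_event entry point; the ids, metric name, and message are made up, and in practice the tracing helpers fill in the trace and span ids:

from datetime import datetime

from llama_toolchain.telemetry.api import (
    LogSeverity,
    MetricEvent,
    SpanStartPayload,
    StructuredLogEvent,
    UnstructuredLogEvent,
)

now = datetime.now()
common = dict(trace_id="trace-1", span_id="span-1", timestamp=now)

events = [
    UnstructuredLogEvent(**common, message="loading shards", severity=LogSeverity.INFO),
    MetricEvent(**common, metric="prompt_tokens", value=42, unit="count"),
    StructuredLogEvent(**common, payload=SpanStartPayload(name="generate")),
]

# Each one goes through the same protocol method on any Telemetry impl:
#     await telemetry.log_event(event)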

View file

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ConsoleConfig
async def get_provider_impl(config: ConsoleConfig, _deps):
from .console import ConsoleTelemetryImpl
impl = ConsoleTelemetryImpl(config)
await impl.initialize()
return impl

View file

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class ConsoleConfig(BaseModel): ...

View file

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_toolchain.telemetry.api import * # noqa: F403
from .config import ConsoleConfig
class ConsoleTelemetryImpl(Telemetry):
def __init__(self, config: ConsoleConfig) -> None:
self.config = config
self.spans = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def log_event(self, event: Event):
if (
isinstance(event, StructuredLogEvent)
and event.payload.type == StructuredLogType.SPAN_START.value
):
self.spans[event.span_id] = event.payload
names = []
span_id = event.span_id
while True:
span_payload = self.spans.get(span_id)
if not span_payload:
break
names = [span_payload.name] + names
span_id = span_payload.parent_span_id
span_name = ".".join(names) if names else None
formatted = format_event(event, span_name)
if formatted:
print(formatted)
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError()
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
SEVERITY_COLORS = {
LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
LogSeverity.DEBUG: COLORS["cyan"],
LogSeverity.INFO: COLORS["green"],
LogSeverity.WARN: COLORS["yellow"],
LogSeverity.ERROR: COLORS["red"],
LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
}
def format_event(event: Event, span_name: str) -> Optional[str]:
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
span = ""
if span_name:
span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
if isinstance(event, UnstructuredLogEvent):
severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
return (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
f"{span}"
f"{event.message}"
)
elif isinstance(event, StructuredLogEvent):
return None
return f"Unknown event type: {event}"

View file

@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_toolchain.core.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.telemetry,
provider_type="console",
pip_packages=[],
module="llama_toolchain.telemetry.console",
config_class="llama_toolchain.telemetry.console.ConsoleConfig",
),
]

View file

@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import logging
import queue
import threading
import uuid
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List
from llama_toolchain.telemetry.api import * # noqa: F403
def generate_short_uuid(len: int = 12):
full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes)
return encoded.rstrip(b"=").decode("ascii")[:len]
CURRENT_TRACE_CONTEXT = None
BACKGROUND_LOGGER = None
class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 1000):
self.api = api
self.log_queue = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
self.worker_thread.start()
def log_event(self, event):
try:
self.log_queue.put_nowait(event)
except queue.Full:
print("Log queue is full, dropping event")
def _process_logs(self):
while True:
try:
event = self.log_queue.get()
# figure out how to use a thread's native loop
asyncio.run(self.api.log_event(event))
except Exception:
import traceback
traceback.print_exc()
print("Error processing log event")
finally:
self.log_queue.task_done()
def __del__(self):
self.log_queue.join()
class TraceContext:
spans: List[Span]
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
self.spans = []
def push_span(self, name: str, attributes: Dict[str, Any] = None):
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanStartPayload(
name=span.name,
parent_span_id=span.parent_span_id,
),
)
)
self.spans.append(span)
def pop_span(self, status: SpanStatus = SpanStatus.OK):
span = self.spans.pop()
if span is not None:
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanEndPayload(
status=status,
),
)
)
def get_current_span(self):
return self.spans[-1] if self.spans else None
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger()
logger.setLevel(level)
logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None):
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
print("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context is None:
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT = None
def severity(levelname: str) -> LogSeverity:
if levelname == "DEBUG":
return LogSeverity.DEBUG
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARN
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":
return LogSeverity.CRITICAL
else:
raise ValueError(f"Unknown log level: {levelname}")
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
def emit(self, record: logging.LogRecord):
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
if record.module in ("asyncio", "selector_events"):
return
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT
if context is None:
return
span = context.get_current_span()
if span is None:
return
BACKGROUND_LOGGER.log_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(),
message=self.format(record),
severity=severity(record.levelname),
)
)
def close(self):
pass
def span(name: str, attributes: Dict[str, Any] = None):
def decorator(func):
@wraps(func)
def sync_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = func(*args, **kwargs)
finally:
if context:
context.pop_span()
return result
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = await func(*args, **kwargs)
finally:
if context:
context.pop_span()
return result
@wraps(func)
def wrapper(*args, **kwargs):
if asyncio.iscoroutinefunction(func):
return async_wrapper(*args, **kwargs)
else:
return sync_wrapper(*args, **kwargs)
return wrapper
return decorator
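
Putting these helpers together, a minimal sketch of a traced call path, assuming setup_logger() has already been called with a telemetry implementation; the function names and attributes are illustrative:

from typing import List

from llama_toolchain.telemetry.tracing import end_trace, span, start_trace


@span(name="tokenize")
async def tokenize(text: str) -> List[str]:
    # Emits SPAN_START / SPAN_END structured events as a child of whatever
    # span is current when it is called.
    return text.split()


async def handle(text: str) -> List[str]:
    await start_trace("handle", {"length": len(text)})
    try:
        return await tokenize(text)
    finally:
        await end_trace()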