mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Merge branch 'api_updates_3' into fix_cli_api_updates_3
This commit is contained in:
commit
28065ca53a
8 changed files with 345 additions and 80 deletions
|
@ -6,19 +6,20 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator, List, Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_models.llama3.api import * # noqa: F403
|
||||||
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from .event_logger import EventLogger
|
from .event_logger import EventLogger
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
|
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
|
||||||
return InferenceClient(config.url)
|
return InferenceClient(config.url)
|
||||||
|
|
5
llama_stack/providers/adapters/telemetry/__init__.py
Normal file
5
llama_stack/providers/adapters/telemetry/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import OpenTelemetryConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
|
||||||
|
from .opentelemetry import OpenTelemetryAdapter
|
||||||
|
|
||||||
|
impl = OpenTelemetryAdapter(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,12 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTelemetryConfig(BaseModel):
|
||||||
|
jaeger_host: str = "localhost"
|
||||||
|
jaeger_port: int = 6831
|
|
@ -0,0 +1,201 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from opentelemetry import metrics, trace
|
||||||
|
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
|
||||||
|
from opentelemetry.sdk.metrics import MeterProvider
|
||||||
|
from opentelemetry.sdk.metrics.export import (
|
||||||
|
ConsoleMetricExporter,
|
||||||
|
PeriodicExportingMetricReader,
|
||||||
|
)
|
||||||
|
from opentelemetry.sdk.resources import Resource
|
||||||
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
|
from opentelemetry.semconv.resource import ResourceAttributes
|
||||||
|
|
||||||
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
|
||||||
|
from .config import OpenTelemetryConfig
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_trace_id(s: str) -> int:
|
||||||
|
# Convert the string to bytes and then to an integer
|
||||||
|
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_span_id(s: str) -> int:
|
||||||
|
# Use only the first 8 bytes (64 bits) for span ID
|
||||||
|
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def is_tracing_enabled(tracer):
|
||||||
|
with tracer.start_as_current_span("check_tracing") as span:
|
||||||
|
return span.is_recording()
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTelemetryAdapter(Telemetry):
|
||||||
|
def __init__(self, config: OpenTelemetryConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.resource = Resource.create(
|
||||||
|
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up tracing with Jaeger exporter
|
||||||
|
jaeger_exporter = JaegerExporter(
|
||||||
|
agent_host_name=self.config.jaeger_host,
|
||||||
|
agent_port=self.config.jaeger_port,
|
||||||
|
)
|
||||||
|
trace_provider = TracerProvider(resource=self.resource)
|
||||||
|
trace_processor = BatchSpanProcessor(jaeger_exporter)
|
||||||
|
trace_provider.add_span_processor(trace_processor)
|
||||||
|
trace.set_tracer_provider(trace_provider)
|
||||||
|
self.tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
# Set up metrics
|
||||||
|
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
|
||||||
|
metric_provider = MeterProvider(
|
||||||
|
resource=self.resource, metric_readers=[metric_reader]
|
||||||
|
)
|
||||||
|
metrics.set_meter_provider(metric_provider)
|
||||||
|
self.meter = metrics.get_meter(__name__)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
trace.get_tracer_provider().shutdown()
|
||||||
|
metrics.get_meter_provider().shutdown()
|
||||||
|
|
||||||
|
async def log_event(self, event: Event) -> None:
|
||||||
|
if isinstance(event, UnstructuredLogEvent):
|
||||||
|
self._log_unstructured(event)
|
||||||
|
elif isinstance(event, MetricEvent):
|
||||||
|
self._log_metric(event)
|
||||||
|
elif isinstance(event, StructuredLogEvent):
|
||||||
|
self._log_structured(event)
|
||||||
|
|
||||||
|
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
|
||||||
|
span = trace.get_current_span()
|
||||||
|
span.add_event(
|
||||||
|
name=event.message,
|
||||||
|
attributes={"severity": event.severity.value, **event.attributes},
|
||||||
|
timestamp=event.timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _log_metric(self, event: MetricEvent) -> None:
|
||||||
|
if isinstance(event.value, int):
|
||||||
|
self.meter.create_counter(
|
||||||
|
name=event.metric,
|
||||||
|
unit=event.unit,
|
||||||
|
description=f"Counter for {event.metric}",
|
||||||
|
).add(event.value, attributes=event.attributes)
|
||||||
|
elif isinstance(event.value, float):
|
||||||
|
self.meter.create_gauge(
|
||||||
|
name=event.metric,
|
||||||
|
unit=event.unit,
|
||||||
|
description=f"Gauge for {event.metric}",
|
||||||
|
).set(event.value, attributes=event.attributes)
|
||||||
|
|
||||||
|
def _log_structured(self, event: StructuredLogEvent) -> None:
|
||||||
|
if isinstance(event.payload, SpanStartPayload):
|
||||||
|
context = trace.set_span_in_context(
|
||||||
|
trace.NonRecordingSpan(
|
||||||
|
trace.SpanContext(
|
||||||
|
trace_id=string_to_trace_id(event.trace_id),
|
||||||
|
span_id=string_to_span_id(event.span_id),
|
||||||
|
is_remote=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
span = self.tracer.start_span(
|
||||||
|
name=event.payload.name,
|
||||||
|
kind=trace.SpanKind.INTERNAL,
|
||||||
|
context=context,
|
||||||
|
attributes=event.attributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
if event.payload.parent_span_id:
|
||||||
|
span.set_parent(
|
||||||
|
trace.SpanContext(
|
||||||
|
trace_id=string_to_trace_id(event.trace_id),
|
||||||
|
span_id=string_to_span_id(event.payload.parent_span_id),
|
||||||
|
is_remote=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(event.payload, SpanEndPayload):
|
||||||
|
span = trace.get_current_span()
|
||||||
|
span.set_status(
|
||||||
|
trace.Status(
|
||||||
|
trace.StatusCode.OK
|
||||||
|
if event.payload.status == SpanStatus.OK
|
||||||
|
else trace.StatusCode.ERROR
|
||||||
|
)
|
||||||
|
)
|
||||||
|
span.end(end_time=event.timestamp)
|
||||||
|
|
||||||
|
async def get_trace(self, trace_id: str) -> Trace:
|
||||||
|
# we need to look up the root span id
|
||||||
|
raise NotImplementedError("not yet no")
|
||||||
|
|
||||||
|
|
||||||
|
# Usage example
|
||||||
|
async def main():
|
||||||
|
telemetry = OpenTelemetryTelemetry("my-service")
|
||||||
|
await telemetry.initialize()
|
||||||
|
|
||||||
|
# Log an unstructured event
|
||||||
|
await telemetry.log_event(
|
||||||
|
UnstructuredLogEvent(
|
||||||
|
trace_id="trace123",
|
||||||
|
span_id="span456",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
message="This is a log message",
|
||||||
|
severity=LogSeverity.INFO,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log a metric event
|
||||||
|
await telemetry.log_event(
|
||||||
|
MetricEvent(
|
||||||
|
trace_id="trace123",
|
||||||
|
span_id="span456",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
metric="my_metric",
|
||||||
|
value=42,
|
||||||
|
unit="count",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log a structured event (span start)
|
||||||
|
await telemetry.log_event(
|
||||||
|
StructuredLogEvent(
|
||||||
|
trace_id="trace123",
|
||||||
|
span_id="span789",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
payload=SpanStartPayload(name="my_operation"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log a structured event (span end)
|
||||||
|
await telemetry.log_event(
|
||||||
|
StructuredLogEvent(
|
||||||
|
trace_id="trace123",
|
||||||
|
span_id="span789",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
payload=SpanEndPayload(status=SpanStatus.OK),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await telemetry.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
|
@ -26,6 +26,7 @@ from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
|
||||||
from .persistence import AgentPersistence
|
from .persistence import AgentPersistence
|
||||||
from .rag.context_retriever import generate_rag_query
|
from .rag.context_retriever import generate_rag_query
|
||||||
|
@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
|
@tracing.span("create_and_execute_turn")
|
||||||
async def create_and_execute_turn(
|
async def create_and_execute_turn(
|
||||||
self, request: AgentTurnCreateRequest
|
self, request: AgentTurnCreateRequest
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
yield final_response
|
yield final_response
|
||||||
|
|
||||||
|
@tracing.span("run_shields")
|
||||||
async def run_multiple_shields_wrapper(
|
async def run_multiple_shields_wrapper(
|
||||||
self,
|
self,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
|
@ -348,9 +351,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
# TODO: find older context from the session and either replace it
|
# TODO: find older context from the session and either replace it
|
||||||
# or append with a sliding window. this is really a very simplistic implementation
|
# or append with a sliding window. this is really a very simplistic implementation
|
||||||
rag_context, bank_ids = await self._retrieve_context(
|
with tracing.span("retrieve_rag_context"):
|
||||||
session_id, input_messages, attachments
|
rag_context, bank_ids = await self._retrieve_context(
|
||||||
)
|
session_id, input_messages, attachments
|
||||||
|
)
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
@ -403,55 +407,57 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
content = ""
|
content = ""
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
async for chunk in self.inference_api.chat_completion(
|
|
||||||
self.agent_config.model,
|
|
||||||
input_messages,
|
|
||||||
tools=self._get_tools(),
|
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
|
||||||
stream=True,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
):
|
|
||||||
event = chunk.event
|
|
||||||
if event.event_type == ChatCompletionResponseEventType.start:
|
|
||||||
continue
|
|
||||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
|
||||||
stop_reason = StopReason.end_of_turn
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta = event.delta
|
with tracing.span("inference"):
|
||||||
if isinstance(delta, ToolCallDelta):
|
async for chunk in self.inference_api.chat_completion(
|
||||||
if delta.parse_status == ToolCallParseStatus.success:
|
self.agent_config.model,
|
||||||
tool_calls.append(delta.content)
|
input_messages,
|
||||||
|
tools=self._get_tools(),
|
||||||
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
|
stream=True,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
):
|
||||||
|
event = chunk.event
|
||||||
|
if event.event_type == ChatCompletionResponseEventType.start:
|
||||||
|
continue
|
||||||
|
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||||
|
stop_reason = StopReason.end_of_turn
|
||||||
|
continue
|
||||||
|
|
||||||
if stream:
|
delta = event.delta
|
||||||
yield AgentTurnResponseStreamChunk(
|
if isinstance(delta, ToolCallDelta):
|
||||||
event=AgentTurnResponseEvent(
|
if delta.parse_status == ToolCallParseStatus.success:
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
tool_calls.append(delta.content)
|
||||||
step_type=StepType.inference.value,
|
|
||||||
step_id=step_id,
|
if stream:
|
||||||
model_response_text_delta="",
|
yield AgentTurnResponseStreamChunk(
|
||||||
tool_call_delta=delta,
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
|
step_type=StepType.inference.value,
|
||||||
|
step_id=step_id,
|
||||||
|
model_response_text_delta="",
|
||||||
|
tool_call_delta=delta,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(delta, str):
|
elif isinstance(delta, str):
|
||||||
content += delta
|
content += delta
|
||||||
if stream and event.stop_reason is None:
|
if stream and event.stop_reason is None:
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
model_response_text_delta=event.delta,
|
model_response_text_delta=event.delta,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
else:
|
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
|
||||||
|
|
||||||
if event.stop_reason is not None:
|
if event.stop_reason is not None:
|
||||||
stop_reason = event.stop_reason
|
stop_reason = event.stop_reason
|
||||||
|
|
||||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||||
message = CompletionMessage(
|
message = CompletionMessage(
|
||||||
|
@ -528,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
result_messages = await execute_tool_call_maybe(
|
with tracing.span("tool_execution"):
|
||||||
self.tools_dict,
|
result_messages = await execute_tool_call_maybe(
|
||||||
[message],
|
self.tools_dict,
|
||||||
)
|
[message],
|
||||||
assert (
|
)
|
||||||
len(result_messages) == 1
|
assert (
|
||||||
), "Currently not supporting multiple messages"
|
len(result_messages) == 1
|
||||||
result_message = result_messages[0]
|
), "Currently not supporting multiple messages"
|
||||||
|
result_message = result_messages[0]
|
||||||
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -669,7 +676,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
for a in attachments
|
for a in attachments
|
||||||
]
|
]
|
||||||
await self.memory_api.insert_documents(bank_id, documents)
|
with tracing.span("insert_documents"):
|
||||||
|
await self.memory_api.insert_documents(bank_id, documents)
|
||||||
else:
|
else:
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
if session_info.memory_bank_id:
|
if session_info.memory_bank_id:
|
||||||
|
|
|
@ -27,4 +27,18 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.telemetry,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_id="opentelemetry-jaeger",
|
||||||
|
pip_packages=[
|
||||||
|
"opentelemetry-api",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
"opentelemetry-exporter-jaeger",
|
||||||
|
"opentelemetry-semantic-conventions",
|
||||||
|
],
|
||||||
|
module="llama_stack.providers.adapters.telemetry.opentelemetry",
|
||||||
|
config_class="llama_stack.providers.adapters.telemetry.opentelemetry.OpenTelemetryConfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -12,7 +12,7 @@ import threading
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
@ -196,33 +196,40 @@ class TelemetryHandler(logging.Handler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def span(name: str, attributes: Dict[str, Any] = None):
|
class SpanContextManager:
|
||||||
def decorator(func):
|
def __init__(self, name: str, attributes: Dict[str, Any] = None):
|
||||||
|
self.name = name
|
||||||
|
self.attributes = attributes
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
global CURRENT_TRACE_CONTEXT
|
||||||
|
context = CURRENT_TRACE_CONTEXT
|
||||||
|
if context:
|
||||||
|
context.push_span(self.name, self.attributes)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
global CURRENT_TRACE_CONTEXT
|
||||||
|
context = CURRENT_TRACE_CONTEXT
|
||||||
|
if context:
|
||||||
|
context.pop_span()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self.__enter__()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.__exit__(exc_type, exc_value, traceback)
|
||||||
|
|
||||||
|
def __call__(self, func: Callable):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def sync_wrapper(*args, **kwargs):
|
def sync_wrapper(*args, **kwargs):
|
||||||
try:
|
with self:
|
||||||
global CURRENT_TRACE_CONTEXT
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
context = CURRENT_TRACE_CONTEXT
|
|
||||||
if context:
|
|
||||||
context.push_span(name, attributes)
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
finally:
|
|
||||||
context.pop_span()
|
|
||||||
return result
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def async_wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
try:
|
async with self:
|
||||||
global CURRENT_TRACE_CONTEXT
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
context = CURRENT_TRACE_CONTEXT
|
|
||||||
if context:
|
|
||||||
context.push_span(name, attributes)
|
|
||||||
result = await func(*args, **kwargs)
|
|
||||||
finally:
|
|
||||||
context.pop_span()
|
|
||||||
return result
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
@ -233,4 +240,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
def span(name: str, attributes: Dict[str, Any] = None):
|
||||||
|
return SpanContextManager(name, attributes)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue