llama-stack-mirror/llama_stack/providers/utils/telemetry/trace_protocol.py
Ashwin Bharambe 2e5bfcd42a
Update Telemetry API so OpenAPI generation can work (#640)
We cannot use recursive types because not only does our OpenAPI
generator not like them, even if it did, it is not easy for all client
languages to automatically construct proper APIs (especially considering
garbage collection) around them. For now, we can return a `Dict[str,
SpanWithStatus]` instead of `SpanWithChildren` and rely on the client to
reconstruct the tree.

Also fixed a super subtle issue with the OpenAPI generation process
(monkey-patching of json_schema_type wasn't working because of import
reordering.)
2024-12-16 13:00:14 -08:00

145 lines
5.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
from datetime import datetime
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from uuid import UUID
from pydantic import BaseModel
T = TypeVar("T")
def serialize_value(value: Any) -> Any:
    """Convert ``value`` into a structure made of JSON-friendly types.

    Rules, applied in order:

    * ``None`` and ``str``/``int``/``float``/``bool`` values are returned as-is.
    * Pydantic models are replaced by the result of ``model_dump()``.
    * Lists, tuples and sets become lists whose items are serialized
      recursively.
    * Dicts become dicts with stringified keys and recursively serialized
      values.
    * Everything else (including ``datetime`` and ``UUID``) becomes
      ``str(value)``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    if isinstance(value, BaseModel):
        return value.model_dump()

    if isinstance(value, (list, tuple, set)):
        return [serialize_value(element) for element in value]

    if isinstance(value, dict):
        return {str(key): serialize_value(entry) for key, entry in value.items()}

    # datetime, UUID and any other unrecognized type share the same fallback:
    # their plain string representation.
    return str(value)
def trace_protocol(cls: Type[T]) -> Type[T]:
    """
    A class decorator that automatically traces the public methods of every
    class inheriting from the decorated protocol/base class.

    The decorator installs an ``__init_subclass__`` hook on ``cls``. Whenever a
    subclass is created, each plain function defined directly on that subclass
    whose name does not start with ``_`` is replaced by a wrapper that runs the
    call inside ``tracing.span(...)``. Methods defined on ``cls`` itself are
    not wrapped.

    Every span is named ``"<ClassName>.<method>"`` and carries the attributes
    ``__autotraced__``, ``__class__``, ``__method__``, ``__type__`` (``"sync"``,
    ``"async"`` or ``"async_generator"``) and ``__args__`` (the stringified,
    serialized call arguments). Wrappers additionally set ``output`` (sync and
    async methods), ``chunk_count`` (async generators) and ``error`` (when the
    call raises an ``Exception``; the exception is always re-raised).

    Args:
        cls: The protocol/base class whose subclasses should be traced.

    Returns:
        The same class object, with its ``__init_subclass__`` hook replaced.
    """
    # Attribute set on every wrapper so that a method is never wrapped twice,
    # e.g. when both a base class and one of its subclasses are decorated.
    traced_marker = "__trace_protocol_wrapped__"

    def trace_method(method: Callable) -> Callable:
        """Return a tracing wrapper matching the sync/async flavour of ``method``."""
        is_async = inspect.iscoroutinefunction(method)
        is_async_gen = inspect.isasyncgenfunction(method)

        # Everything below depends only on ``method``, so it is computed once
        # at wrap time instead of on every call.
        method_name = method.__name__
        span_type = (
            "async_generator" if is_async_gen else "async" if is_async else "sync"
        )
        param_names = list(inspect.signature(method).parameters)[1:]  # Skip 'self'

        def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
            """Build ``(class_name, method_name, span_attributes)`` for one call."""
            class_name = self.__class__.__name__

            combined_args = {}
            for i, arg in enumerate(args):
                # Positional values beyond the declared parameters get a
                # synthetic, 1-based name.
                param_name = (
                    param_names[i] if i < len(param_names) else f"position_{i+1}"
                )
                combined_args[param_name] = serialize_value(arg)
            for k, v in kwargs.items():
                combined_args[str(k)] = serialize_value(v)

            span_attributes = {
                "__autotraced__": True,
                "__class__": class_name,
                "__method__": method_name,
                "__type__": span_type,
                "__args__": str(combined_args),
            }
            return class_name, method_name, span_attributes

        @wraps(method)
        async def async_gen_wrapper(
            self: Any, *args: Any, **kwargs: Any
        ) -> AsyncGenerator:
            # NOTE(review): imported at call time rather than module level,
            # presumably to avoid a circular import -- confirm before moving.
            from llama_stack.providers.utils.telemetry import tracing

            class_name, method_name, span_attributes = create_span_context(
                self, *args, **kwargs
            )

            with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                count = 0
                try:
                    async for item in method(self, *args, **kwargs):
                        yield item
                        count += 1
                except Exception as e:
                    span.set_attribute("error", str(e))
                    raise
                finally:
                    # Recorded even when iteration stops early or fails.
                    span.set_attribute("chunk_count", count)

        @wraps(method)
        async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            from llama_stack.providers.utils.telemetry import tracing

            class_name, method_name, span_attributes = create_span_context(
                self, *args, **kwargs
            )

            with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                try:
                    result = await method(self, *args, **kwargs)
                    span.set_attribute("output", serialize_value(result))
                    return result
                except Exception as e:
                    span.set_attribute("error", str(e))
                    raise

        @wraps(method)
        def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            from llama_stack.providers.utils.telemetry import tracing

            class_name, method_name, span_attributes = create_span_context(
                self, *args, **kwargs
            )

            with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                try:
                    result = method(self, *args, **kwargs)
                    span.set_attribute("output", serialize_value(result))
                    return result
                except Exception as e:
                    # Record the failure the same way the async wrapper does.
                    span.set_attribute("error", str(e))
                    raise

        if is_async_gen:
            wrapper = async_gen_wrapper
        elif is_async:
            wrapper = async_wrapper
        else:
            wrapper = sync_wrapper
        setattr(wrapper, traced_marker, True)
        return wrapper

    # A hook defined directly on ``cls`` is stored as a classmethod object; keep
    # its underlying function so it can be re-bound to each new subclass.
    # (``getattr(cls, "__init_subclass__")`` would return a method already bound
    # to ``cls``, which would then run against the wrong class.)
    own_hook = vars(cls).get("__init_subclass__")
    own_hook_func = getattr(own_hook, "__func__", None)

    def __init_subclass__(cls_child, **kwargs):  # noqa: N807
        if own_hook_func is not None:
            own_hook_func(cls_child, **kwargs)
        else:
            # No hook defined on ``cls`` itself: continue along the MRO.
            super(cls, cls_child).__init_subclass__(**kwargs)

        # Iterate over a snapshot because attributes are replaced in the loop.
        for name, method in list(vars(cls_child).items()):
            if (
                inspect.isfunction(method)
                and not name.startswith("_")
                and not getattr(method, traced_marker, False)
            ):
                setattr(cls_child, name, trace_method(method))  # noqa: B010

    cls.__init_subclass__ = classmethod(__init_subclass__)

    return cls