From c2a4850a7963cd4678b8b7c79657f44364949b35 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 26 Nov 2024 11:28:24 -0800 Subject: [PATCH] tracing decorator for apiis --- llama_stack/apis/inference/inference.py | 3 + llama_stack/apis/memory/memory.py | 2 + llama_stack/apis/shields/shields.py | 2 + llama_stack/distribution/tracing.py | 134 ++++++++++++++++++++++++ 4 files changed, 141 insertions(+) create mode 100644 llama_stack/distribution/tracing.py diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 5aadd97c7..85b29a147 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.distribution.tracing import trace_protocol + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 @@ -220,6 +222,7 @@ class ModelStore(Protocol): @runtime_checkable +@trace_protocol class Inference(Protocol): model_store: ModelStore diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 48b6e2241..b75df8a1a 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.distribution.tracing import trace_protocol @json_schema_type @@ -43,6 +44,7 @@ class MemoryBankStore(Protocol): @runtime_checkable +@trace_protocol class Memory(Protocol): memory_bank_store: MemoryBankStore diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 5ee444f68..b28605727 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol class CommonShieldFields(BaseModel): @@ -38,6 +39,7 @@ class ShieldInput(CommonShieldFields): @runtime_checkable +@trace_protocol class Shields(Protocol): @webmethod(route="/shields/list", method="GET") async def list_shields(self) -> List[Shield]: ... diff --git a/llama_stack/distribution/tracing.py b/llama_stack/distribution/tracing.py new file mode 100644 index 000000000..caddeee72 --- /dev/null +++ b/llama_stack/distribution/tracing.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import inspect +from functools import wraps +from typing import Any, AsyncGenerator, Callable, Type, TypeVar + +from llama_stack.providers.utils.telemetry import tracing + +T = TypeVar("T") + + +def trace_protocol(cls: Type[T]) -> Type[T]: + """ + A class decorator that automatically traces all methods in a protocol/base class + and its inheriting classes. Supports sync methods, async methods, and async generators. + """ + + def trace_method(method: Callable) -> Callable: + is_async = asyncio.iscoroutinefunction(method) + is_async_gen = inspect.isasyncgenfunction(method) + + @wraps(method) + async def async_gen_wrapper( + self: Any, *args: Any, **kwargs: Any + ) -> AsyncGenerator: + class_name = self.__class__.__name__ + method_name = f"{class_name}.{method.__name__}" + + args_repr = [repr(arg) for arg in args] + kwargs_repr = [f"{k}={repr(v)}" for k, v in kwargs.items()] + signature = ", ".join(args_repr + kwargs_repr) + + with tracing.span( + f"{class_name}.{method_name}", + { + "class": class_name, + "method": method_name, + "signature": signature, + "type": "async_generator", + }, + ) as span: + output = [] + try: + async for item in method(self, *args, **kwargs): + output.append(item) + yield item + except Exception as e: + raise + finally: + span.set_attribute("output", output) + + @wraps(method) + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name = self.__class__.__name__ + method_name = f"{class_name}.{method.__name__}" + + args_repr = [repr(arg) for arg in args] + kwargs_repr = [f"{k}={repr(v)}" for k, v in kwargs.items()] + signature = ", ".join(args_repr + kwargs_repr) + + with tracing.span( + f"{class_name}.{method_name}", + { + "class": class_name, + "method": method_name, + "signature": signature, + "type": "async", + }, + ): + try: + result = await method(self, *args, **kwargs) + return result + except Exception as e: + raise + + @wraps(method) + def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name = self.__class__.__name__ + method_name = f"{class_name}.{method.__name__}" + + args_repr = [repr(arg) for arg in args] + kwargs_repr = [f"{k}={repr(v)}" for k, v in kwargs.items()] + signature = ", ".join(args_repr + kwargs_repr) + + with tracing.span( + f"{class_name}.{method_name}", + { + "class": class_name, + "method": method_name, + "signature": signature, + "type": "sync", + }, + ): + try: + result = method(self, *args, **kwargs) + return result + except Exception as e: + raise + + if is_async_gen: + return async_gen_wrapper + elif is_async: + return async_wrapper + else: + return sync_wrapper + + # Trace all existing methods in the base class + for name, method in vars(cls).items(): + if inspect.isfunction(method) and not name.startswith("__"): + setattr(cls, name, trace_method(method)) + + # Store the original __init_subclass__ if it exists + original_init_subclass = getattr(cls, "__init_subclass__", None) + + # Define a new __init_subclass__ to handle child classes + def __init_subclass__(cls_child, **kwargs): # noqa: N807 + # Call original __init_subclass__ if it exists + if original_init_subclass: + original_init_subclass(**kwargs) + + # Trace all methods defined in the child class + for name, method in vars(cls_child).items(): + if inspect.isfunction(method) and not name.startswith("__"): + setattr(cls_child, name, trace_method(method)) + + # Set the new __init_subclass__ + cls.__init_subclass__ = classmethod(__init_subclass__) + + return cls