tracing for APIs

This commit is contained in:
Dinesh Yeduguru 2024-11-26 14:07:48 -08:00
parent c2a4850a79
commit af8a1fe5b3
8 changed files with 126 additions and 63 deletions

View file

@ -23,6 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.distribution.tracing import trace_protocol, traced
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.deployment_types import * # noqa: F403 from llama_stack.apis.common.deployment_types import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
@ -418,6 +419,7 @@ class AgentStepResponse(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol
class Agents(Protocol): class Agents(Protocol):
@webmethod(route="/agents/create") @webmethod(route="/agents/create")
async def create_agent( async def create_agent(
@ -426,6 +428,7 @@ class Agents(Protocol):
) -> AgentCreateResponse: ... ) -> AgentCreateResponse: ...
@webmethod(route="/agents/turn/create") @webmethod(route="/agents/turn/create")
@traced(input="messages")
async def create_agent_turn( async def create_agent_turn(
self, self,
agent_id: str, agent_id: str,

View file

@ -21,7 +21,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.distribution.tracing import trace_protocol from llama_stack.distribution.tracing import trace_protocol, traced
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403
@ -227,6 +227,7 @@ class Inference(Protocol):
model_store: ModelStore model_store: ModelStore
@webmethod(route="/inference/completion") @webmethod(route="/inference/completion")
@traced(input="content")
async def completion( async def completion(
self, self,
model_id: str, model_id: str,
@ -238,6 +239,7 @@ class Inference(Protocol):
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ... ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat-completion") @webmethod(route="/inference/chat-completion")
@traced(input="messages")
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
@ -255,6 +257,7 @@ class Inference(Protocol):
]: ... ]: ...
@webmethod(route="/inference/embeddings") @webmethod(route="/inference/embeddings")
@traced(input="contents")
async def embeddings( async def embeddings(
self, self,
model_id: str, model_id: str,

View file

@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.tracing import trace_protocol from llama_stack.distribution.tracing import trace_protocol, traced
@json_schema_type @json_schema_type
@ -50,6 +50,7 @@ class Memory(Protocol):
# this will just block now until documents are inserted, but it should # this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion # probably return a Job instance which can be polled for completion
@traced(input="documents")
@webmethod(route="/memory/insert") @webmethod(route="/memory/insert")
async def insert_documents( async def insert_documents(
self, self,
@ -59,6 +60,7 @@ class Memory(Protocol):
) -> None: ... ) -> None: ...
@webmethod(route="/memory/query") @webmethod(route="/memory/query")
@traced(input="query")
async def query_documents( async def query_documents(
self, self,
bank_id: str, bank_id: str,

View file

@ -20,6 +20,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.distribution.tracing import trace_protocol
@json_schema_type @json_schema_type
@ -129,6 +130,7 @@ class MemoryBankInput(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol
class MemoryBanks(Protocol): class MemoryBanks(Protocol):
@webmethod(route="/memory-banks/list", method="GET") @webmethod(route="/memory-banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ... async def list_memory_banks(self) -> List[MemoryBank]: ...

View file

@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.distribution.tracing import trace_protocol
class CommonModelFields(BaseModel): class CommonModelFields(BaseModel):
@ -43,6 +44,7 @@ class ModelInput(CommonModelFields):
@runtime_checkable @runtime_checkable
@trace_protocol
class Models(Protocol): class Models(Protocol):
@webmethod(route="/models/list", method="GET") @webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[Model]: ... async def list_models(self) -> List[Model]: ...

View file

@ -10,6 +10,8 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.tracing import trace_protocol, traced
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403
@ -43,10 +45,12 @@ class ShieldStore(Protocol):
@runtime_checkable @runtime_checkable
@trace_protocol
class Safety(Protocol): class Safety(Protocol):
shield_store: ShieldStore shield_store: ShieldStore
@webmethod(route="/safety/run-shield") @webmethod(route="/safety/run-shield")
@traced(input="messages")
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,

View file

@ -6,98 +6,132 @@
import asyncio import asyncio
import inspect import inspect
import json
from functools import wraps from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from pydantic import BaseModel
from llama_stack.providers.utils.telemetry import tracing from llama_stack.providers.utils.telemetry import tracing
T = TypeVar("T") T = TypeVar("T")
def serialize_value(value: Any) -> str:
"""Helper function to serialize values to string representation."""
try:
if isinstance(value, BaseModel):
return value.model_dump_json()
elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
return json.dumps([item.model_dump_json() for item in value])
elif hasattr(value, "to_dict"): # For objects with to_dict method
return json.dumps(value.to_dict())
elif isinstance(value, (dict, list, int, float, str, bool)):
return json.dumps(value)
else:
return str(value)
except Exception:
return str(value)
def traced(input: str = None):
"""
A method decorator that enables tracing with input and output capture.
Args:
input: Name of the input parameter to capture in traces
"""
def decorator(method: Callable) -> Callable:
method._trace_input = input
return method
return decorator
def trace_protocol(cls: Type[T]) -> Type[T]: def trace_protocol(cls: Type[T]) -> Type[T]:
""" """
A class decorator that automatically traces all methods in a protocol/base class A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes. Supports sync methods, async methods, and async generators. and its inheriting classes.
""" """
def trace_method(method: Callable) -> Callable: def trace_method(method: Callable) -> Callable:
is_async = asyncio.iscoroutinefunction(method) is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method) is_async_gen = inspect.isasyncgenfunction(method)
def get_traced_input(args: tuple, kwargs: dict) -> dict:
trace_input = getattr(method, "_trace_input", None)
if not trace_input:
return {}
# Get the mapping of parameter names to values
sig = inspect.signature(method)
bound_args = sig.bind(None, *args, **kwargs) # None for self
bound_args.apply_defaults()
params = dict(list(bound_args.arguments.items())[1:]) # Skip 'self'
# Return the input value if the key exists
if trace_input in params:
return {"input": serialize_value(params[trace_input])}
return {}
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = (
"async_generator" if is_async_gen else "async" if is_async else "sync"
)
span_attributes = {
"class": class_name,
"method": method_name,
"type": span_type,
"args": serialize_value(args),
**get_traced_input(args, kwargs),
}
return class_name, method_name, span_attributes
@wraps(method) @wraps(method)
async def async_gen_wrapper( async def async_gen_wrapper(
self: Any, *args: Any, **kwargs: Any self: Any, *args: Any, **kwargs: Any
) -> AsyncGenerator: ) -> AsyncGenerator:
class_name = self.__class__.__name__ class_name, method_name, span_attributes = create_span_context(
method_name = f"{class_name}.{method.__name__}" self, *args, **kwargs
)
args_repr = [repr(arg) for arg in args] with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
kwargs_repr = [f"{k}={repr(v)}" for k, v in kwargs.items()]
signature = ", ".join(args_repr + kwargs_repr)
with tracing.span(
f"{class_name}.{method_name}",
{
"class": class_name,
"method": method_name,
"signature": signature,
"type": "async_generator",
},
) as span:
output = []
try: try:
async for item in method(self, *args, **kwargs): async for item in method(self, *args, **kwargs):
output.append(item)
yield item yield item
except Exception as e:
raise
finally: finally:
span.set_attribute("output", output) span.set_attribute("output", "streaming output")
@wraps(method) @wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name = self.__class__.__name__ class_name, method_name, span_attributes = create_span_context(
method_name = f"{class_name}.{method.__name__}" self, *args, **kwargs
)
args_repr = [repr(arg) for arg in args] with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
kwargs_repr = [f"{k}={repr(v)}" for k, v in kwargs.items()]
signature = ", ".join(args_repr + kwargs_repr)
with tracing.span(
f"{class_name}.{method_name}",
{
"class": class_name,
"method": method_name,
"signature": signature,
"type": "async",
},
):
try: try:
result = await method(self, *args, **kwargs) result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result return result
except Exception as e: except Exception as e:
span.set_attribute("error", str(e))
raise raise
@wraps(method) @wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name = self.__class__.__name__ class_name, method_name, span_attributes = create_span_context(
method_name = f"{class_name}.{method.__name__}" self, *args, **kwargs
)
args_repr = [repr(arg) for arg in args] with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
kwargs_repr = [f"{k}={repr(v)}" for k, v in kwargs.items()]
signature = ", ".join(args_repr + kwargs_repr)
with tracing.span(
f"{class_name}.{method_name}",
{
"class": class_name,
"method": method_name,
"signature": signature,
"type": "sync",
},
):
try: try:
result = method(self, *args, **kwargs) result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result return result
except Exception as e: except Exception as e:
raise raise
@ -109,11 +143,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
else: else:
return sync_wrapper return sync_wrapper
# Trace all existing methods in the base class
for name, method in vars(cls).items():
if inspect.isfunction(method) and not name.startswith("__"):
setattr(cls, name, trace_method(method))
# Store the original __init_subclass__ if it exists # Store the original __init_subclass__ if it exists
original_init_subclass = getattr(cls, "__init_subclass__", None) original_init_subclass = getattr(cls, "__init_subclass__", None)
@ -123,10 +152,20 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
if original_init_subclass: if original_init_subclass:
original_init_subclass(**kwargs) original_init_subclass(**kwargs)
# Trace all methods defined in the child class traced_methods = {}
for parent in cls_child.__mro__[1:]: # Skip the class itself
for name, method in vars(parent).items():
if inspect.isfunction(method) and method._trace_input:
traced_methods[name] = method._trace_input
# Trace child class methods if their name matches a traced parent method
for name, method in vars(cls_child).items(): for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("__"): if inspect.isfunction(method) and not name.startswith("_"):
setattr(cls_child, name, trace_method(method)) if name in traced_methods:
# Copy the trace configuration from the parent
method._trace_input = traced_methods[name]
cls_child.__dict__[name] = trace_method(method)
# Set the new __init_subclass__ # Set the new __init_subclass__
cls.__init_subclass__ = classmethod(__init_subclass__) cls.__init_subclass__ = classmethod(__init_subclass__)

View file

@ -204,6 +204,7 @@ class SpanContextManager:
def __init__(self, name: str, attributes: Dict[str, Any] = None): def __init__(self, name: str, attributes: Dict[str, Any] = None):
self.name = name self.name = name
self.attributes = attributes self.attributes = attributes
self.span = None
def __enter__(self): def __enter__(self):
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
@ -225,10 +226,17 @@ class SpanContextManager:
self.span.attributes[key] = value self.span.attributes[key] = value
async def __aenter__(self): async def __aenter__(self):
return self.__enter__() global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback): async def __aexit__(self, exc_type, exc_value, traceback):
self.__exit__(exc_type, exc_value, traceback) global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
def __call__(self, func: Callable): def __call__(self, func: Callable):
@wraps(func) @wraps(func)