This commit is contained in:
Dinesh Yeduguru 2024-11-26 15:41:08 -08:00
parent af8a1fe5b3
commit b3e149334a
3 changed files with 22 additions and 6 deletions

View file

@ -155,17 +155,19 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
traced_methods = {} traced_methods = {}
for parent in cls_child.__mro__[1:]: # Skip the class itself for parent in cls_child.__mro__[1:]: # Skip the class itself
for name, method in vars(parent).items(): for name, method in vars(parent).items():
if inspect.isfunction(method) and method._trace_input: if inspect.isfunction(method) and getattr(
traced_methods[name] = method._trace_input method, "_trace_input", None
): # noqa: B009
traced_methods[name] = getattr(method, "_trace_input") # noqa: B009
# Trace child class methods if their name matches a traced parent method # Trace child class methods if their name matches a traced parent method
for name, method in vars(cls_child).items(): for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("_"): if inspect.isfunction(method) and not name.startswith("_"):
if name in traced_methods: if name in traced_methods:
# Copy the trace configuration from the parent # Copy the trace configuration from the parent
method._trace_input = traced_methods[name] setattr(method, "_trace_input", traced_methods[name]) # noqa: B010
cls_child.__dict__[name] = trace_method(method) setattr(cls_child, name, trace_method(method)) # noqa: B010
# Set the new __init_subclass__ # Set the new __init_subclass__
cls.__init_subclass__ = classmethod(__init_subclass__) cls.__init_subclass__ = classmethod(__init_subclass__)

View file

@ -568,7 +568,13 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
with tracing.span("tool_execution"): with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
result_messages = await execute_tool_call_maybe( result_messages = await execute_tool_call_maybe(
self.tools_dict, self.tools_dict,
[message], [message],
@ -577,6 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
len(result_messages) == 1 len(result_messages) == 1
), "Currently not supporting multiple messages" ), "Currently not supporting multiple messages"
result_message = result_messages[0] result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(

View file

@ -5,7 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Optional from typing import List, Optional
from llama_stack.apis.telemetry.telemetry import Trace
from .config import LogFormat from .config import LogFormat
@ -52,6 +54,11 @@ class ConsoleTelemetryImpl(Telemetry):
async def get_trace(self, trace_id: str) -> Trace: async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError() raise NotImplementedError()
async def get_traces_for_session(
self, session_id: str, lookback: str = "1h", limit: int = 100
) -> List[Trace]:
raise NotImplementedError()
COLORS = { COLORS = {
"reset": "\033[0m", "reset": "\033[0m",