enhance the tracing span utility to make it a context manager

Ashwin Bharambe 2024-09-22 20:55:15 -07:00
parent 484dc2e4f5
commit 5d75c2437b
2 changed files with 96 additions and 77 deletions
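
Before this change, tracing.span(name, attributes) returned a decorator; after it, the same call returns a SpanContextManager that still works as a decorator but can also be entered as a synchronous or asynchronous context manager. A minimal sketch of the two call styles (the function bodies here are hypothetical; only tracing.span and the span names come from the diff):

    from llama_stack.providers.utils.telemetry import tracing

    # Decorator style -- unchanged for existing callers:
    @tracing.span("create_and_execute_turn")
    async def create_and_execute_turn(request):
        ...

    # Context-manager style -- newly possible with this commit:
    async def run_turn():
        with tracing.span("inference"):
            ...  # work done here is recorded under the "inference" span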

View file

@@ -26,6 +26,7 @@ from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.apis.safety import * # noqa: F403
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.telemetry import tracing

 from .persistence import AgentPersistence
 from .rag.context_retriever import generate_rag_query
@@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)

+    @tracing.span("create_and_execute_turn")
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
@@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin):
         yield final_response

+    @tracing.span("run_shields")
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
@@ -348,9 +351,10 @@ class ChatAgent(ShieldRunnerMixin):
         # TODO: find older context from the session and either replace it
         # or append with a sliding window. this is really a very simplistic implementation
-        rag_context, bank_ids = await self._retrieve_context(
-            session_id, input_messages, attachments
-        )
+        with tracing.span("retrieve_rag_context"):
+            rag_context, bank_ids = await self._retrieve_context(
+                session_id, input_messages, attachments
+            )

         step_id = str(uuid.uuid4())
         yield AgentTurnResponseStreamChunk(
@@ -403,55 +407,57 @@ class ChatAgent(ShieldRunnerMixin):
         tool_calls = []
         content = ""
         stop_reason = None
-        async for chunk in self.inference_api.chat_completion(
-            self.agent_config.model,
-            input_messages,
-            tools=self._get_tools(),
-            tool_prompt_format=self.agent_config.tool_prompt_format,
-            stream=True,
-            sampling_params=sampling_params,
-        ):
-            event = chunk.event
-            if event.event_type == ChatCompletionResponseEventType.start:
-                continue
-            elif event.event_type == ChatCompletionResponseEventType.complete:
-                stop_reason = StopReason.end_of_turn
-                continue
-
-            delta = event.delta
-            if isinstance(delta, ToolCallDelta):
-                if delta.parse_status == ToolCallParseStatus.success:
-                    tool_calls.append(delta.content)
-
-                if stream:
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
-                                step_id=step_id,
-                                model_response_text_delta="",
-                                tool_call_delta=delta,
-                            )
-                        )
-                    )
-
-            elif isinstance(delta, str):
-                content += delta
-                if stream and event.stop_reason is None:
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
-                                step_id=step_id,
-                                model_response_text_delta=event.delta,
-                            )
-                        )
-                    )
-            else:
-                raise ValueError(f"Unexpected delta type {type(delta)}")
-
-            if event.stop_reason is not None:
-                stop_reason = event.stop_reason
+        with tracing.span("inference"):
+            async for chunk in self.inference_api.chat_completion(
+                self.agent_config.model,
+                input_messages,
+                tools=self._get_tools(),
+                tool_prompt_format=self.agent_config.tool_prompt_format,
+                stream=True,
+                sampling_params=sampling_params,
+            ):
+                event = chunk.event
+                if event.event_type == ChatCompletionResponseEventType.start:
+                    continue
+                elif event.event_type == ChatCompletionResponseEventType.complete:
+                    stop_reason = StopReason.end_of_turn
+                    continue
+
+                delta = event.delta
+                if isinstance(delta, ToolCallDelta):
+                    if delta.parse_status == ToolCallParseStatus.success:
+                        tool_calls.append(delta.content)
+
+                    if stream:
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepProgressPayload(
+                                    step_type=StepType.inference.value,
+                                    step_id=step_id,
+                                    model_response_text_delta="",
+                                    tool_call_delta=delta,
+                                )
+                            )
+                        )
+
+                elif isinstance(delta, str):
+                    content += delta
+                    if stream and event.stop_reason is None:
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepProgressPayload(
+                                    step_type=StepType.inference.value,
+                                    step_id=step_id,
+                                    model_response_text_delta=event.delta,
+                                )
+                            )
+                        )
+                else:
+                    raise ValueError(f"Unexpected delta type {type(delta)}")
+
+                if event.stop_reason is not None:
+                    stop_reason = event.stop_reason

         stop_reason = stop_reason or StopReason.out_of_tokens
         message = CompletionMessage(
@@ -528,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )

-            result_messages = await execute_tool_call_maybe(
-                self.tools_dict,
-                [message],
-            )
-            assert (
-                len(result_messages) == 1
-            ), "Currently not supporting multiple messages"
-            result_message = result_messages[0]
+            with tracing.span("tool_execution"):
+                result_messages = await execute_tool_call_maybe(
+                    self.tools_dict,
+                    [message],
+                )
+                assert (
+                    len(result_messages) == 1
+                ), "Currently not supporting multiple messages"
+                result_message = result_messages[0]

             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -669,7 +676,8 @@ class ChatAgent(ShieldRunnerMixin):
                 )
                 for a in attachments
             ]
-            await self.memory_api.insert_documents(bank_id, documents)
+            with tracing.span("insert_documents"):
+                await self.memory_api.insert_documents(bank_id, documents)
         else:
             session_info = await self.storage.get_session_info(session_id)
             if session_info.memory_bank_id:
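
Because __enter__ pushes a span onto the ambient CURRENT_TRACE_CONTEXT and __exit__ pops it, the with-blocks added above nest naturally under the span created by the @tracing.span decorator on the enclosing turn. A rough sketch of the equivalent nesting (span names from the diff; the trace-context plumbing is elided):

    with tracing.span("create_and_execute_turn"):      # pushed by the decorator
        with tracing.span("retrieve_rag_context"):     # child span
            ...
        with tracing.span("inference"):                # next child span
            ...
        with tracing.span("tool_execution"):           # one per tool call
            ...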

View file

@@ -12,7 +12,7 @@ import threading
 import uuid
 from datetime import datetime
 from functools import wraps
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List

 from llama_stack.apis.telemetry import * # noqa: F403
@@ -196,33 +196,42 @@ class TelemetryHandler(logging.Handler):
         pass


-def span(name: str, attributes: Dict[str, Any] = None):
-    def decorator(func):
-        @wraps(func)
-        def sync_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-
-                result = func(*args, **kwargs)
-            finally:
-                context.pop_span()
-
-            return result
-
-        @wraps(func)
-        async def async_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-
-                result = await func(*args, **kwargs)
-            finally:
-                context.pop_span()
-
-            return result
+class SpanContextManager:
+    def __init__(self, name: str, attributes: Dict[str, Any] = None):
+        self.name = name
+        self.attributes = attributes
+
+    def __enter__(self):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.push_span(self.name, self.attributes)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.pop_span()
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+    def __call__(self, func: Callable):
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            print("sync wrapper")
+            with self:
+                return func(*args, **kwargs)
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            print("async wrapper")
+            async with self:
+                return await func(*args, **kwargs)

         @wraps(func)
         def wrapper(*args, **kwargs):
@@ -233,4 +242,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
         return wrapper

-    return decorator
+
+def span(name: str, attributes: Dict[str, Any] = None):
+    return SpanContextManager(name, attributes)
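
For reference, the decorator/context-manager dual-use pattern introduced here reduces to a few lines. A self-contained sketch under the same design (names and the print-based bookkeeping are mine, not the library's; the real class also implements __aenter__/__aexit__ and routes coroutine functions to an async wrapper):

    from functools import wraps

    class Span:
        def __init__(self, name: str):
            self.name = name

        def __enter__(self):
            print(f"push {self.name}")  # stands in for context.push_span(...)
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            print(f"pop {self.name}")   # stands in for context.pop_span()

        def __call__(self, func):
            # Calling the instance lets it double as @span("name").
            @wraps(func)
            def wrapper(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)
            return wrapper

    def span(name: str) -> Span:
        return Span(name)

    @span("outer")
    def work():
        with span("inner"):
            return 42

    work()  # prints: push outer / push inner / pop inner / pop outer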