mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
enhance the tracing span utility to make a context manager
This commit is contained in:
parent
484dc2e4f5
commit
5d75c2437b
2 changed files with 96 additions and 77 deletions
|
@ -26,6 +26,7 @@ from llama_stack.apis.memory import * # noqa: F403
|
|||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .persistence import AgentPersistence
|
||||
from .rag.context_retriever import generate_rag_query
|
||||
|
@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
@tracing.span("create_and_execute_turn")
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
|
@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
yield final_response
|
||||
|
||||
@tracing.span("run_shields")
|
||||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
|
@ -348,9 +351,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session_id, input_messages, attachments
|
||||
)
|
||||
with tracing.span("retrieve_rag_context"):
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session_id, input_messages, attachments
|
||||
)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -403,55 +407,57 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_calls = []
|
||||
content = ""
|
||||
stop_reason = None
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=self._get_tools(),
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
continue
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
continue
|
||||
|
||||
delta = event.delta
|
||||
if isinstance(delta, ToolCallDelta):
|
||||
if delta.parse_status == ToolCallParseStatus.success:
|
||||
tool_calls.append(delta.content)
|
||||
with tracing.span("inference"):
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=self._get_tools(),
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
continue
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
continue
|
||||
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
model_response_text_delta="",
|
||||
tool_call_delta=delta,
|
||||
delta = event.delta
|
||||
if isinstance(delta, ToolCallDelta):
|
||||
if delta.parse_status == ToolCallParseStatus.success:
|
||||
tool_calls.append(delta.content)
|
||||
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
model_response_text_delta="",
|
||||
tool_call_delta=delta,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(delta, str):
|
||||
content += delta
|
||||
if stream and event.stop_reason is None:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
model_response_text_delta=event.delta,
|
||||
elif isinstance(delta, str):
|
||||
content += delta
|
||||
if stream and event.stop_reason is None:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
model_response_text_delta=event.delta,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||
else:
|
||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||
|
||||
if event.stop_reason is not None:
|
||||
stop_reason = event.stop_reason
|
||||
if event.stop_reason is not None:
|
||||
stop_reason = event.stop_reason
|
||||
|
||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||
message = CompletionMessage(
|
||||
|
@ -528,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
[message],
|
||||
)
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
with tracing.span("tool_execution"):
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
[message],
|
||||
)
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -669,7 +676,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
for a in attachments
|
||||
]
|
||||
await self.memory_api.insert_documents(bank_id, documents)
|
||||
with tracing.span("insert_documents"):
|
||||
await self.memory_api.insert_documents(bank_id, documents)
|
||||
else:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info.memory_bank_id:
|
||||
|
|
|
@ -12,7 +12,7 @@ import threading
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
@ -196,33 +196,42 @@ class TelemetryHandler(logging.Handler):
|
|||
pass
|
||||
|
||||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
def decorator(func):
|
||||
class SpanContextManager:
|
||||
def __init__(self, name: str, attributes: Dict[str, Any] = None):
|
||||
self.name = name
|
||||
self.attributes = attributes
|
||||
|
||||
def __enter__(self):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.pop_span()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.__enter__()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
self.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def __call__(self, func: Callable):
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(name, attributes)
|
||||
result = func(*args, **kwargs)
|
||||
finally:
|
||||
context.pop_span()
|
||||
return result
|
||||
print("sync wrapper")
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(name, attributes)
|
||||
result = await func(*args, **kwargs)
|
||||
finally:
|
||||
context.pop_span()
|
||||
return result
|
||||
print("async wrapper")
|
||||
async with self:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
@ -233,4 +242,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
|
|||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
return SpanContextManager(name, attributes)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue