forked from phoenix-oss/llama-stack-mirror
fix: tracing fixes for trace context propagation across coroutines (#1522)
# What does this PR do? This PR has two fixes needed for correct trace context propagation across the asyncio boundary. Fix 1: Start using context vars to store the global trace context. This is needed since we cannot use the same trace context across coroutines since the state is shared. Each coroutine should have its own trace context so that each of them can start storing its state correctly. Fix 2: Start a new span for each new coroutine started for running shields to keep the span tree clean. ## Test Plan ### Integration tests with server LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/together/together-run.yaml LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct server logs: https://gist.github.com/dineshyv/51ac5d9864ed031d0d89ce77352821fe test logs: https://gist.github.com/dineshyv/e66acc1c4648a42f1854600609c467f3 ### Integration tests with library client LLAMA_STACK_CONFIG=fireworks pytest -s --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct logs: https://gist.github.com/dineshyv/ca160696a0b167223378673fb1dcefb8 ### Apps test with server: ``` LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/together/together-run.yaml python -m examples.agents.e2e_loop_with_client_tools localhost 8321 ``` server logs: https://gist.github.com/dineshyv/1717a572d8f7c14279c36123b79c5797 app logs: https://gist.github.com/dineshyv/44167e9f57806a0ba3b710c32aec02f8
This commit is contained in:
parent
e3edca7739
commit
ead9397e22
3 changed files with 54 additions and 35 deletions
|
@ -181,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
with tracing.span("create_and_execute_turn") as span:
|
async with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
span.set_attribute("request", request.model_dump_json())
|
span.set_attribute("request", request.model_dump_json())
|
||||||
|
@ -191,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
with tracing.span("resume_turn") as span:
|
async with tracing.span("resume_turn") as span:
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("turn_id", request.turn_id)
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
|
@ -390,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
shields: List[str],
|
shields: List[str],
|
||||||
touchpoint: str,
|
touchpoint: str,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
with tracing.span("run_shields") as span:
|
async with tracing.span("run_shields") as span:
|
||||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||||
if len(shields) == 0:
|
if len(shields) == 0:
|
||||||
span.set_attribute("output", "no shields")
|
span.set_attribute("output", "no shields")
|
||||||
|
@ -508,7 +508,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
content = ""
|
content = ""
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
with tracing.span("inference") as span:
|
async with tracing.span("inference") as span:
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
|
@ -685,7 +685,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_name = tool_call.tool_name
|
tool_name = tool_call.tool_name
|
||||||
if isinstance(tool_name, BuiltinTool):
|
if isinstance(tool_name, BuiltinTool):
|
||||||
tool_name = tool_name.value
|
tool_name = tool_name.value
|
||||||
with tracing.span(
|
async with tracing.span(
|
||||||
"tool_execution",
|
"tool_execution",
|
||||||
{
|
{
|
||||||
"tool_name": tool_name,
|
"tool_name": tool_name,
|
||||||
|
|
|
@ -10,6 +10,7 @@ from typing import List
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||||
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||||
responses = await asyncio.gather(
|
async def run_shield_with_span(identifier: str):
|
||||||
*[
|
async with tracing.span(f"run_shield_{identifier}"):
|
||||||
self.safety_api.run_shield(
|
return await self.safety_api.run_shield(
|
||||||
shield_id=identifier,
|
shield_id=identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
for identifier in identifiers
|
|
||||||
]
|
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||||
)
|
|
||||||
for identifier, response in zip(identifiers, responses, strict=False):
|
for identifier, response in zip(identifiers, responses, strict=False):
|
||||||
if not response.violation:
|
if not response.violation:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
|
||||||
Telemetry,
|
Telemetry,
|
||||||
UnstructuredLogEvent,
|
UnstructuredLogEvent,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
def generate_short_uuid(len: int = 8):
|
def generate_short_uuid(len: int = 8):
|
||||||
|
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
|
||||||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
return encoded.rstrip(b"=").decode("ascii")[:len]
|
||||||
|
|
||||||
|
|
||||||
CURRENT_TRACE_CONTEXT = None
|
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
|
||||||
BACKGROUND_LOGGER = None
|
BACKGROUND_LOGGER = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +53,7 @@ class BackgroundLogger:
|
||||||
try:
|
try:
|
||||||
self.log_queue.put_nowait(event)
|
self.log_queue.put_nowait(event)
|
||||||
except queue.Full:
|
except queue.Full:
|
||||||
log.error("Log queue is full, dropping event")
|
logger.error("Log queue is full, dropping event")
|
||||||
|
|
||||||
def _process_logs(self):
|
def _process_logs(self):
|
||||||
while True:
|
while True:
|
||||||
|
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
|
||||||
|
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||||
logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
logger.setLevel(level)
|
root_logger.setLevel(level)
|
||||||
logger.addHandler(TelemetryHandler())
|
root_logger.addHandler(TelemetryHandler())
|
||||||
|
|
||||||
|
|
||||||
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
||||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||||
|
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
log.info("No Telemetry implementation set. Skipping trace initialization...")
|
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||||
return
|
return
|
||||||
|
|
||||||
trace_id = generate_short_uuid(16)
|
trace_id = generate_short_uuid(16)
|
||||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||||
|
|
||||||
CURRENT_TRACE_CONTEXT = context
|
CURRENT_TRACE_CONTEXT.set(context)
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
|
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context is None:
|
if context is None:
|
||||||
|
logger.debug("No trace context to end")
|
||||||
return
|
return
|
||||||
|
|
||||||
context.pop_span(status)
|
context.pop_span(status)
|
||||||
CURRENT_TRACE_CONTEXT = None
|
CURRENT_TRACE_CONTEXT.set(None)
|
||||||
|
|
||||||
|
|
||||||
def severity(levelname: str) -> LogSeverity:
|
def severity(levelname: str) -> LogSeverity:
|
||||||
|
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
raise RuntimeError("Telemetry API not initialized")
|
raise RuntimeError("Telemetry API not initialized")
|
||||||
|
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context is None:
|
if context is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -218,15 +221,21 @@ class SpanContextManager:
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to push span")
|
||||||
|
return self
|
||||||
|
|
||||||
self.span = context.push_span(self.name, self.attributes)
|
self.span = context.push_span(self.name, self.attributes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to pop span")
|
||||||
|
return
|
||||||
|
|
||||||
context.pop_span()
|
context.pop_span()
|
||||||
|
|
||||||
def set_attribute(self, key: str, value: Any):
|
def set_attribute(self, key: str, value: Any):
|
||||||
|
@ -237,15 +246,21 @@ class SpanContextManager:
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to push span")
|
||||||
|
return self
|
||||||
|
|
||||||
self.span = context.push_span(self.name, self.attributes)
|
self.span = context.push_span(self.name, self.attributes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to pop span")
|
||||||
|
return
|
||||||
|
|
||||||
context.pop_span()
|
context.pop_span()
|
||||||
|
|
||||||
def __call__(self, func: Callable):
|
def __call__(self, func: Callable):
|
||||||
|
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
|
||||||
|
|
||||||
def get_current_span() -> Optional[Span]:
|
def get_current_span() -> Optional[Span]:
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
if CURRENT_TRACE_CONTEXT is None:
|
||||||
|
logger.debug("No trace context to get current span")
|
||||||
|
return None
|
||||||
|
|
||||||
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if context:
|
||||||
return context.get_current_span()
|
return context.get_current_span()
|
||||||
return None
|
return None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue