mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fix: Tracing fixes for correct trace context propagation
This commit is contained in:
parent
0db3a2f511
commit
38805e0ca1
3 changed files with 54 additions and 35 deletions
|
@ -181,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
with tracing.span("create_and_execute_turn") as span:
|
async with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
span.set_attribute("request", request.model_dump_json())
|
span.set_attribute("request", request.model_dump_json())
|
||||||
|
@ -191,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
with tracing.span("resume_turn") as span:
|
async with tracing.span("resume_turn") as span:
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("turn_id", request.turn_id)
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
|
@ -390,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
shields: List[str],
|
shields: List[str],
|
||||||
touchpoint: str,
|
touchpoint: str,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
with tracing.span("run_shields") as span:
|
async with tracing.span("run_shields") as span:
|
||||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||||
if len(shields) == 0:
|
if len(shields) == 0:
|
||||||
span.set_attribute("output", "no shields")
|
span.set_attribute("output", "no shields")
|
||||||
|
@ -508,7 +508,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
content = ""
|
content = ""
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
with tracing.span("inference") as span:
|
async with tracing.span("inference") as span:
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
|
@ -685,7 +685,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_name = tool_call.tool_name
|
tool_name = tool_call.tool_name
|
||||||
if isinstance(tool_name, BuiltinTool):
|
if isinstance(tool_name, BuiltinTool):
|
||||||
tool_name = tool_name.value
|
tool_name = tool_name.value
|
||||||
with tracing.span(
|
async with tracing.span(
|
||||||
"tool_execution",
|
"tool_execution",
|
||||||
{
|
{
|
||||||
"tool_name": tool_name,
|
"tool_name": tool_name,
|
||||||
|
|
|
@ -10,6 +10,7 @@ from typing import List
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||||
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||||
responses = await asyncio.gather(
|
async def run_shield_with_span(identifier: str):
|
||||||
*[
|
async with tracing.span(f"run_shield_{identifier}"):
|
||||||
self.safety_api.run_shield(
|
return await self.safety_api.run_shield(
|
||||||
shield_id=identifier,
|
shield_id=identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
for identifier in identifiers
|
|
||||||
]
|
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||||
)
|
|
||||||
for identifier, response in zip(identifiers, responses, strict=False):
|
for identifier, response in zip(identifiers, responses, strict=False):
|
||||||
if not response.violation:
|
if not response.violation:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
|
||||||
Telemetry,
|
Telemetry,
|
||||||
UnstructuredLogEvent,
|
UnstructuredLogEvent,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
def generate_short_uuid(len: int = 8):
|
def generate_short_uuid(len: int = 8):
|
||||||
|
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
|
||||||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
return encoded.rstrip(b"=").decode("ascii")[:len]
|
||||||
|
|
||||||
|
|
||||||
CURRENT_TRACE_CONTEXT = None
|
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
|
||||||
BACKGROUND_LOGGER = None
|
BACKGROUND_LOGGER = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +53,7 @@ class BackgroundLogger:
|
||||||
try:
|
try:
|
||||||
self.log_queue.put_nowait(event)
|
self.log_queue.put_nowait(event)
|
||||||
except queue.Full:
|
except queue.Full:
|
||||||
log.error("Log queue is full, dropping event")
|
logger.error("Log queue is full, dropping event")
|
||||||
|
|
||||||
def _process_logs(self):
|
def _process_logs(self):
|
||||||
while True:
|
while True:
|
||||||
|
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
|
||||||
|
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||||
logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
logger.setLevel(level)
|
root_logger.setLevel(level)
|
||||||
logger.addHandler(TelemetryHandler())
|
root_logger.addHandler(TelemetryHandler())
|
||||||
|
|
||||||
|
|
||||||
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
||||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||||
|
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
log.info("No Telemetry implementation set. Skipping trace initialization...")
|
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||||
return
|
return
|
||||||
|
|
||||||
trace_id = generate_short_uuid(16)
|
trace_id = generate_short_uuid(16)
|
||||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||||
|
|
||||||
CURRENT_TRACE_CONTEXT = context
|
CURRENT_TRACE_CONTEXT.set(context)
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
|
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context is None:
|
if context is None:
|
||||||
|
logger.debug("No trace context to end")
|
||||||
return
|
return
|
||||||
|
|
||||||
context.pop_span(status)
|
context.pop_span(status)
|
||||||
CURRENT_TRACE_CONTEXT = None
|
CURRENT_TRACE_CONTEXT.set(None)
|
||||||
|
|
||||||
|
|
||||||
def severity(levelname: str) -> LogSeverity:
|
def severity(levelname: str) -> LogSeverity:
|
||||||
|
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
raise RuntimeError("Telemetry API not initialized")
|
raise RuntimeError("Telemetry API not initialized")
|
||||||
|
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context is None:
|
if context is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -218,15 +221,21 @@ class SpanContextManager:
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to push span")
|
||||||
|
return self
|
||||||
|
|
||||||
self.span = context.push_span(self.name, self.attributes)
|
self.span = context.push_span(self.name, self.attributes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to pop span")
|
||||||
|
return
|
||||||
|
|
||||||
context.pop_span()
|
context.pop_span()
|
||||||
|
|
||||||
def set_attribute(self, key: str, value: Any):
|
def set_attribute(self, key: str, value: Any):
|
||||||
|
@ -237,15 +246,21 @@ class SpanContextManager:
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to push span")
|
||||||
|
return self
|
||||||
|
|
||||||
self.span = context.push_span(self.name, self.attributes)
|
self.span = context.push_span(self.name, self.attributes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if not context:
|
||||||
|
logger.debug("No trace context to pop span")
|
||||||
|
return
|
||||||
|
|
||||||
context.pop_span()
|
context.pop_span()
|
||||||
|
|
||||||
def __call__(self, func: Callable):
|
def __call__(self, func: Callable):
|
||||||
|
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
|
||||||
|
|
||||||
def get_current_span() -> Optional[Span]:
|
def get_current_span() -> Optional[Span]:
|
||||||
global CURRENT_TRACE_CONTEXT
|
global CURRENT_TRACE_CONTEXT
|
||||||
context = CURRENT_TRACE_CONTEXT
|
if CURRENT_TRACE_CONTEXT is None:
|
||||||
|
logger.debug("No trace context to get current span")
|
||||||
|
return None
|
||||||
|
|
||||||
|
context = CURRENT_TRACE_CONTEXT.get()
|
||||||
if context:
|
if context:
|
||||||
return context.get_current_span()
|
return context.get_current_span()
|
||||||
return None
|
return None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue