mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
add tracing back to the lib cli
This commit is contained in:
parent
e2054d53e4
commit
84904914c2
5 changed files with 106 additions and 61 deletions
|
@ -24,6 +24,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
|
@ -32,6 +33,12 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
|
@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return False
|
||||
|
||||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
||||
|
@ -276,14 +286,20 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
async def _call_non_streaming(
|
||||
self, path: str, body: dict = None, cast_to: Any = None
|
||||
):
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
return convert_pydantic_to_json_value(await func(**body), cast_to)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
@ -291,6 +307,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = self._convert_body(path, body)
|
||||
async for chunk in await func(**body):
|
||||
yield convert_pydantic_to_json_value(chunk, cast_to)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
||||
if not body:
|
||||
|
|
|
@ -142,7 +142,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
with tracing.SpanContextManager("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
|
@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
# log.info(
|
||||
# f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
# )
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
|
@ -279,8 +279,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("run_shields") as span:
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
with tracing.SpanContextManager("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
|
@ -360,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
with tracing.span("retrieve_rag_context") as span:
|
||||
with tracing.SpanContextManager("retrieve_rag_context") as span:
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session_id, input_messages, attachments
|
||||
)
|
||||
|
@ -405,11 +404,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
n_iter = 0
|
||||
while True:
|
||||
msg = input_messages[-1]
|
||||
# if len(str(msg)) > 1000:
|
||||
# msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
|
||||
# else:
|
||||
# msg_str = str(msg)
|
||||
# log.info(f"{msg_str}")
|
||||
if len(str(msg)) > 1000:
|
||||
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
|
||||
else:
|
||||
msg_str = str(msg)
|
||||
log.info(f"{msg_str}")
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -425,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
with tracing.SpanContextManager("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
@ -514,12 +513,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
# log.info("Done with MAX iterations, exiting.")
|
||||
log.info("Done with MAX iterations, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
# log.info("Out of token budget, exiting.")
|
||||
log.info("Out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
@ -533,10 +532,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + attachments
|
||||
yield message
|
||||
else:
|
||||
# log.info(f"Partial message: {str(message)}")
|
||||
log.info(f"Partial message: {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
# log.info(f"{str(message)}")
|
||||
log.info(f"{str(message)}")
|
||||
try:
|
||||
tool_call = message.tool_calls[0]
|
||||
|
||||
|
@ -564,7 +563,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
with tracing.span(
|
||||
with tracing.SpanContextManager(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
|
@ -713,7 +712,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
for a in attachments
|
||||
]
|
||||
with tracing.span("insert_documents"):
|
||||
with tracing.SpanContextManager("insert_documents"):
|
||||
await self.memory_api.insert_documents(bank_id, documents)
|
||||
else:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
|
@ -800,7 +799,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
# log.info(f"Downloading {url} -> {filepath}")
|
||||
log.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
|
|
|
@ -20,6 +20,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
"""Initialize the SQLite span processor with a connection string."""
|
||||
self.conn_string = conn_string
|
||||
self.ttl_days = ttl_days
|
||||
self._shutdown_event = threading.Event()
|
||||
self.cleanup_task = None
|
||||
self._thread_local = threading.local()
|
||||
self._connections: Dict[int, sqlite3.Connection] = {}
|
||||
|
@ -144,8 +145,9 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
"""Run cleanup periodically."""
|
||||
import time
|
||||
|
||||
while True:
|
||||
while not self._shutdown_event.is_set():
|
||||
time.sleep(3600) # Sleep for 1 hour
|
||||
if not self._shutdown_event.is_set():
|
||||
self._cleanup_old_data()
|
||||
|
||||
def on_start(self, span: Span, parent_context=None):
|
||||
|
@ -231,11 +233,23 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
|
||||
def shutdown(self):
|
||||
"""Cleanup any resources."""
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Wait for cleanup thread to finish if it exists
|
||||
if self.cleanup_task and self.cleanup_task.is_alive():
|
||||
self.cleanup_task.join(timeout=5.0)
|
||||
current_thread_id = threading.get_ident()
|
||||
|
||||
with self._lock:
|
||||
for conn in self._connections.values():
|
||||
# Close all connections from the current thread
|
||||
for thread_id, conn in list(self._connections.items()):
|
||||
if thread_id == current_thread_id:
|
||||
try:
|
||||
if conn:
|
||||
conn.close()
|
||||
self._connections.clear()
|
||||
del self._connections[thread_id]
|
||||
except sqlite3.Error:
|
||||
pass # Ignore errors during shutdown
|
||||
|
||||
def force_flush(self, timeout_millis=30000):
|
||||
"""Force export of spans."""
|
||||
|
|
|
@ -6,29 +6,31 @@
|
|||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def serialize_value(value: Any) -> str:
|
||||
"""Helper function to serialize values to string representation."""
|
||||
try:
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump_json()
|
||||
elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
|
||||
return json.dumps([item.model_dump_json() for item in value])
|
||||
elif hasattr(value, "to_dict"):
|
||||
return json.dumps(value.to_dict())
|
||||
elif isinstance(value, (dict, list, int, float, str, bool)):
|
||||
return json.dumps(value)
|
||||
else:
|
||||
def serialize_value(value: Any) -> Any:
|
||||
"""Serialize a single value into JSON-compatible format."""
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
elif isinstance(value, BaseModel):
|
||||
return value.model_dump()
|
||||
elif isinstance(value, (list, tuple, set)):
|
||||
return [serialize_value(item) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {str(k): serialize_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, (datetime, UUID)):
|
||||
return str(value)
|
||||
except Exception:
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
|
||||
|
@ -47,16 +49,26 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
|
||||
class_name = self.__class__.__name__
|
||||
method_name = method.__name__
|
||||
|
||||
span_type = (
|
||||
"async_generator" if is_async_gen else "async" if is_async else "sync"
|
||||
)
|
||||
sig = inspect.signature(method)
|
||||
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
|
||||
combined_args = {}
|
||||
for i, arg in enumerate(args):
|
||||
param_name = (
|
||||
param_names[i] if i < len(param_names) else f"position_{i+1}"
|
||||
)
|
||||
combined_args[param_name] = serialize_value(arg)
|
||||
for k, v in kwargs.items():
|
||||
combined_args[str(k)] = serialize_value(v)
|
||||
|
||||
span_attributes = {
|
||||
"__autotraced__": True,
|
||||
"__class__": class_name,
|
||||
"__method__": method_name,
|
||||
"__type__": span_type,
|
||||
"__args__": serialize_value(args),
|
||||
"__args__": str(combined_args),
|
||||
}
|
||||
|
||||
return class_name, method_name, span_attributes
|
||||
|
@ -69,7 +81,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
self, *args, **kwargs
|
||||
)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
with tracing.SpanContextManager(
|
||||
f"{class_name}.{method_name}", span_attributes
|
||||
) as span:
|
||||
try:
|
||||
count = 0
|
||||
async for item in method(self, *args, **kwargs):
|
||||
|
@ -84,7 +98,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
self, *args, **kwargs
|
||||
)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
with tracing.SpanContextManager(
|
||||
f"{class_name}.{method_name}", span_attributes
|
||||
) as span:
|
||||
try:
|
||||
result = await method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
|
@ -99,7 +115,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
self, *args, **kwargs
|
||||
)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
with tracing.SpanContextManager(
|
||||
f"{class_name}.{method_name}", span_attributes
|
||||
) as span:
|
||||
try:
|
||||
result = method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
|
|
|
@ -259,10 +259,6 @@ class SpanContextManager:
|
|||
return wrapper
|
||||
|
||||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
return SpanContextManager(name, attributes)
|
||||
|
||||
|
||||
def get_current_span() -> Optional[Span]:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue