add tracing back to the lib cli (#595)
Adds back all the tracing logic that was removed from the library client, and restores the logging calls in agent_instance.
parent 1c03ba239e
commit e128f2547a
4 changed files with 76 additions and 117 deletions
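
In short: when the stack is built with a telemetry provider, the library client now registers it as the tracing log sink via setup_logger(), and every client call is bracketed by start_trace()/end_trace() so the trace is closed even if the call raises. A minimal sketch of that pattern (the traced_call helper name is illustrative only; the three imported functions are the ones added back in the first diff below):

from llama_stack.providers.utils.telemetry.tracing import end_trace, setup_logger, start_trace


async def traced_call(path, func, **body):
    # Hypothetical helper showing the restored wrapping pattern.
    await start_trace(path, {"__location__": "library_client"})
    try:
        return await func(**body)
    finally:
        await end_trace()  # always close the trace, even when func raises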
File 1 of 4: the library client (AsyncLlamaStackAsLibraryClient)

@@ -24,6 +24,7 @@ from termcolor import cprint
 
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -32,6 +33,12 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
 )
 
+from llama_stack.providers.utils.telemetry.tracing import (
+    end_trace,
+    setup_logger,
+    start_trace,
+)
+
 T = TypeVar("T")
 
 
@@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             return False
 
+        if Api.telemetry in self.impls:
+            setup_logger(self.impls[Api.telemetry])
+
         console = Console()
         console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
         console.print(yaml.dump(self.config.model_dump(), indent=2))
@@ -276,21 +286,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
     async def _call_non_streaming(
         self, path: str, body: dict = None, cast_to: Any = None
     ):
-        func = self.endpoint_impls.get(path)
-        if not func:
-            raise ValueError(f"No endpoint found for {path}")
+        await start_trace(path, {"__location__": "library_client"})
+        try:
+            func = self.endpoint_impls.get(path)
+            if not func:
+                raise ValueError(f"No endpoint found for {path}")
 
-        body = self._convert_body(path, body)
-        return convert_pydantic_to_json_value(await func(**body), cast_to)
+            body = self._convert_body(path, body)
+            return convert_pydantic_to_json_value(await func(**body), cast_to)
+        finally:
+            await end_trace()
 
     async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
-        func = self.endpoint_impls.get(path)
-        if not func:
-            raise ValueError(f"No endpoint found for {path}")
+        await start_trace(path, {"__location__": "library_client"})
+        try:
+            func = self.endpoint_impls.get(path)
+            if not func:
+                raise ValueError(f"No endpoint found for {path}")
 
-        body = self._convert_body(path, body)
-        async for chunk in await func(**body):
-            yield convert_pydantic_to_json_value(chunk, cast_to)
+            body = self._convert_body(path, body)
+            async for chunk in await func(**body):
+                yield convert_pydantic_to_json_value(chunk, cast_to)
+        finally:
+            await end_trace()
 
     def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
         if not body:
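
A subtlety in the streaming path above: _call_streaming is an async generator, so the finally: await end_trace() does not run when the method is first called, it runs when the consumer exhausts or closes the generator. A small self-contained sketch of that behavior (stand-in functions, not repo code):

import asyncio


async def fake_start_trace(name):  # stand-in for start_trace()
    print(f"start trace: {name}")


async def fake_end_trace():  # stand-in for end_trace()
    print("end trace")


async def call_streaming(path):
    await fake_start_trace(path)
    try:
        for i in range(3):  # stand-in for `async for chunk in await func(**body)`
            yield i
    finally:
        await fake_end_trace()  # runs after the last chunk, or if the consumer stops early


async def main():
    async for chunk in call_streaming("/v1/chat"):
        print("chunk", chunk)


asyncio.run(main())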
File 2 of 4: the agent implementation, agent_instance (ChatAgent)

@@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
             stream=request.stream,
         ):
             if isinstance(chunk, CompletionMessage):
-                # log.info(
-                #     f"{chunk.role.capitalize()}: {chunk.content}",
-                # )
+                log.info(
+                    f"{chunk.role.capitalize()}: {chunk.content}",
+                )
                 output_message = chunk
                 continue
 
@@ -280,7 +280,6 @@ class ChatAgent(ShieldRunnerMixin):
         touchpoint: str,
     ) -> AsyncGenerator:
         with tracing.span("run_shields") as span:
-            span.set_attribute("turn_id", turn_id)
             span.set_attribute("input", [m.model_dump_json() for m in messages])
             if len(shields) == 0:
                 span.set_attribute("output", "no shields")
@@ -405,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin):
         n_iter = 0
         while True:
             msg = input_messages[-1]
-            # if len(str(msg)) > 1000:
-            #     msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
-            # else:
-            #     msg_str = str(msg)
-            # log.info(f"{msg_str}")
 
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
@@ -514,12 +508,12 @@ class ChatAgent(ShieldRunnerMixin):
             )
 
             if n_iter >= self.agent_config.max_infer_iters:
-                # log.info("Done with MAX iterations, exiting.")
+                log.info("Done with MAX iterations, exiting.")
                 yield message
                 break
 
             if stop_reason == StopReason.out_of_tokens:
-                # log.info("Out of token budget, exiting.")
+                log.info("Out of token budget, exiting.")
                 yield message
                 break
 
@@ -533,10 +527,10 @@ class ChatAgent(ShieldRunnerMixin):
                     message.content = [message.content] + attachments
                 yield message
             else:
-                # log.info(f"Partial message: {str(message)}")
+                log.info(f"Partial message: {str(message)}")
                 input_messages = input_messages + [message]
         else:
-            # log.info(f"{str(message)}")
+            log.info(f"{str(message)}")
             try:
                 tool_call = message.tool_calls[0]
 
@@ -800,7 +794,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
         path = urlparse(uri).path
         basename = os.path.basename(path)
         filepath = f"{tempdir}/{make_random_string() + basename}"
-        # log.info(f"Downloading {url} -> {filepath}")
+        log.info(f"Downloading {url} -> {filepath}")
 
         async with httpx.AsyncClient() as client:
             r = await client.get(uri)
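
The agent-side changes mostly re-enable log.info calls that had been commented out; the run_shields hunk also stops recording the turn_id attribute on the span. For reference, a rough sketch of the span pattern used above (the import path is an assumption based on the tracing module used by the library client; run_shields_traced is not a real function):

from llama_stack.providers.utils.telemetry import tracing


async def run_shields_traced(messages, shields):
    # Sketch only: record shield inputs/outputs on a span, mirroring the hunk above.
    with tracing.span("run_shields") as span:
        span.set_attribute("input", [m.model_dump_json() for m in messages])
        if len(shields) == 0:
            span.set_attribute("output", "no shields")
            return
        # ... run each shield and record its result on the same span ...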
File 3 of 4: the SQLite span processor (SQLiteSpanProcessor)

@@ -7,33 +7,24 @@
 import json
 import os
 import sqlite3
-import threading
-from datetime import datetime, timedelta
-from typing import Dict
+from datetime import datetime
 
 from opentelemetry.sdk.trace import SpanProcessor
 from opentelemetry.trace import Span
 
 
 class SQLiteSpanProcessor(SpanProcessor):
-    def __init__(self, conn_string, ttl_days=30):
+    def __init__(self, conn_string):
         """Initialize the SQLite span processor with a connection string."""
         self.conn_string = conn_string
-        self.ttl_days = ttl_days
-        self.cleanup_task = None
-        self._thread_local = threading.local()
-        self._connections: Dict[int, sqlite3.Connection] = {}
-        self._lock = threading.Lock()
+        self.conn = None
         self.setup_database()
 
     def _get_connection(self) -> sqlite3.Connection:
-        """Get a thread-specific database connection."""
-        thread_id = threading.get_ident()
-        with self._lock:
-            if thread_id not in self._connections:
-                conn = sqlite3.connect(self.conn_string)
-                self._connections[thread_id] = conn
-            return self._connections[thread_id]
+        """Get the database connection."""
+        if self.conn is None:
+            self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
+        return self.conn
 
     def setup_database(self):
         """Create the necessary tables if they don't exist."""
@@ -94,60 +85,6 @@ class SQLiteSpanProcessor(SpanProcessor):
         conn.commit()
         cursor.close()
 
-        # Start periodic cleanup in a separate thread
-        self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True)
-        self.cleanup_task.start()
-
-    def _cleanup_old_data(self):
-        """Delete records older than TTL."""
-        try:
-            conn = self._get_connection()
-            cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat()
-            cursor = conn.cursor()
-
-            # Delete old span events
-            cursor.execute(
-                """
-                DELETE FROM span_events
-                WHERE span_id IN (
-                    SELECT span_id FROM spans
-                    WHERE trace_id IN (
-                        SELECT trace_id FROM traces
-                        WHERE created_at < ?
-                    )
-                )
-                """,
-                (cutoff_date,),
-            )
-
-            # Delete old spans
-            cursor.execute(
-                """
-                DELETE FROM spans
-                WHERE trace_id IN (
-                    SELECT trace_id FROM traces
-                    WHERE created_at < ?
-                )
-                """,
-                (cutoff_date,),
-            )
-
-            # Delete old traces
-            cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,))
-
-            conn.commit()
-            cursor.close()
-        except Exception as e:
-            print(f"Error during cleanup: {e}")
-
-    def _periodic_cleanup(self):
-        """Run cleanup periodically."""
-        import time
-
-        while True:
-            time.sleep(3600)  # Sleep for 1 hour
-            self._cleanup_old_data()
-
     def on_start(self, span: Span, parent_context=None):
         """Called when a span starts."""
         pass
@@ -231,11 +168,9 @@ class SQLiteSpanProcessor(SpanProcessor):
 
     def shutdown(self):
         """Cleanup any resources."""
-        with self._lock:
-            for conn in self._connections.values():
-                if conn:
-                    conn.close()
-            self._connections.clear()
+        if self.conn:
+            self.conn.close()
+            self.conn = None
 
     def force_flush(self, timeout_millis=30000):
         """Force export of spans."""
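
The span processor now keeps a single shared connection opened with check_same_thread=False, instead of a per-thread connection map guarded by a lock plus a TTL cleanup thread. The flag matters because OpenTelemetry span processors are typically driven from exporter/worker threads rather than the thread that opened the connection. A small standalone illustration of that constraint (not repo code; serializing access across threads remains the caller's responsibility):

import sqlite3
import threading

# With the default check_same_thread=True, using this connection from the
# writer thread below would raise sqlite3.ProgrammingError.
conn = sqlite3.connect(":memory:", check_same_thread=False)
conn.execute("CREATE TABLE IF NOT EXISTS traces (trace_id TEXT PRIMARY KEY)")


def writer():
    conn.execute("INSERT OR IGNORE INTO traces VALUES ('abc123')")
    conn.commit()


t = threading.Thread(target=writer)
t.start()
t.join()
print(conn.execute("SELECT COUNT(*) FROM traces").fetchone()[0])  # -> 1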
File 4 of 4: the tracing protocol utilities (trace_protocol / serialize_value)

@@ -6,29 +6,31 @@
 
 import asyncio
 import inspect
-import json
+from datetime import datetime
 from functools import wraps
 from typing import Any, AsyncGenerator, Callable, Type, TypeVar
+from uuid import UUID
 
 from pydantic import BaseModel
 
 T = TypeVar("T")
 
 
-def serialize_value(value: Any) -> str:
-    """Helper function to serialize values to string representation."""
-    try:
-        if isinstance(value, BaseModel):
-            return value.model_dump_json()
-        elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
-            return json.dumps([item.model_dump_json() for item in value])
-        elif hasattr(value, "to_dict"):
-            return json.dumps(value.to_dict())
-        elif isinstance(value, (dict, list, int, float, str, bool)):
-            return json.dumps(value)
-        else:
-            return str(value)
-    except Exception:
+def serialize_value(value: Any) -> Any:
+    """Serialize a single value into JSON-compatible format."""
+    if value is None:
+        return None
+    elif isinstance(value, (str, int, float, bool)):
+        return value
+    elif isinstance(value, BaseModel):
+        return value.model_dump()
+    elif isinstance(value, (list, tuple, set)):
+        return [serialize_value(item) for item in value]
+    elif isinstance(value, dict):
+        return {str(k): serialize_value(v) for k, v in value.items()}
+    elif isinstance(value, (datetime, UUID)):
+        return str(value)
+    else:
         return str(value)
 
 
@@ -47,16 +49,26 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
         def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
             class_name = self.__class__.__name__
             method_name = method.__name__
 
             span_type = (
                 "async_generator" if is_async_gen else "async" if is_async else "sync"
             )
+            sig = inspect.signature(method)
+            param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
+            combined_args = {}
+            for i, arg in enumerate(args):
+                param_name = (
+                    param_names[i] if i < len(param_names) else f"position_{i+1}"
+                )
+                combined_args[param_name] = serialize_value(arg)
+            for k, v in kwargs.items():
+                combined_args[str(k)] = serialize_value(v)
 
             span_attributes = {
                 "__autotraced__": True,
                 "__class__": class_name,
                 "__method__": method_name,
                 "__type__": span_type,
-                "__args__": serialize_value(args),
+                "__args__": str(combined_args),
             }
 
             return class_name, method_name, span_attributes
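
Net effect of the trace_protocol changes: serialize_value() now returns JSON-compatible Python values instead of pre-encoded strings, and positional arguments are mapped to their parameter names before being attached to the span as __args__. A quick, hypothetical check of both behaviors (chat_completion below is a made-up signature; serialize_value is a trimmed copy of the function in the diff):

import inspect
from datetime import datetime


def serialize_value(value):
    # Trimmed copy: enough branches for this demo (no BaseModel/UUID handling).
    if value is None:
        return None
    elif isinstance(value, (str, int, float, bool)):
        return value
    elif isinstance(value, (list, tuple, set)):
        return [serialize_value(v) for v in value]
    elif isinstance(value, dict):
        return {str(k): serialize_value(v) for k, v in value.items()}
    else:
        return str(value)


def chat_completion(self, model_id, messages, stream=False):  # made-up method signature
    ...


sig = inspect.signature(chat_completion)
param_names = list(sig.parameters.keys())[1:]  # skip 'self'
args = ("llama-3", [{"role": "user", "content": "hi"}])
combined_args = {
    (param_names[i] if i < len(param_names) else f"position_{i+1}"): serialize_value(a)
    for i, a in enumerate(args)
}
print(combined_args)
# {'model_id': 'llama-3', 'messages': [{'role': 'user', 'content': 'hi'}]}
print(serialize_value({"when": datetime(2024, 12, 1)}))
# {'when': '2024-12-01 00:00:00'}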