diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 693e2f56c..3a87f0c97 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -22,6 +22,7 @@ from termcolor import cprint from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.datatypes import Api from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.stack import ( @@ -29,6 +30,11 @@ from llama_stack.distribution.stack import ( get_stack_run_config_from_template, replace_env_vars, ) +from llama_stack.providers.utils.telemetry.tracing import ( + end_trace, + setup_logger, + start_trace, +) T = TypeVar("T") @@ -187,6 +193,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return False + # Set up telemetry logger similar to server.py + if Api.telemetry in self.impls: + setup_logger(self.impls[Api.telemetry]) + console = Console() console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") console.print(yaml.dump(self.config.model_dump(), indent=2)) @@ -234,21 +244,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): return await self._call_non_streaming(path, "POST", body) async def _call_non_streaming(self, path: str, method: str, body: dict = None): - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") + await start_trace(path, {"__location__": "library_client"}) + try: + func = self.endpoint_impls.get(path) + if not func: + raise ValueError(f"No endpoint found for {path}") - body = self._convert_body(path, body) - return await func(**body) + body = self._convert_body(path, body) + return await func(**body) + finally: + end_trace() async def _call_streaming(self, path: str, method: str, body: dict = None): - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") + await start_trace(path, {"__location__": "library_client"}) + try: + func = self.endpoint_impls.get(path) + if not func: + raise ValueError(f"No endpoint found for {path}") - body = self._convert_body(path, body) - async for chunk in await func(**body): - yield chunk + body = self._convert_body(path, body) + async for chunk in await func(**body): + yield chunk + finally: + end_trace() def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: if not body: diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py index 553dd5000..f8fdbc12f 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -20,6 +20,7 @@ class SQLiteSpanProcessor(SpanProcessor): """Initialize the SQLite span processor with a connection string.""" self.conn_string = conn_string self.ttl_days = ttl_days + self._shutdown_event = threading.Event() self.cleanup_task = None self._thread_local = threading.local() self._connections: Dict[int, sqlite3.Connection] = {} @@ -144,9 +145,10 @@ class SQLiteSpanProcessor(SpanProcessor): """Run cleanup periodically.""" import time - while True: + while not self._shutdown_event.is_set(): time.sleep(3600) # Sleep for 1 hour - self._cleanup_old_data() + if not self._shutdown_event.is_set(): + self._cleanup_old_data() def on_start(self, span: Span, parent_context=None): """Called when a span starts.""" @@ -231,11 +233,23 @@ class SQLiteSpanProcessor(SpanProcessor): def shutdown(self): """Cleanup any resources.""" + self._shutdown_event.set() + + # Wait for cleanup thread to finish if it exists + if self.cleanup_task and self.cleanup_task.is_alive(): + self.cleanup_task.join(timeout=5.0) + current_thread_id = threading.get_ident() + with self._lock: - for conn in self._connections.values(): - if conn: - conn.close() - self._connections.clear() + # Close all connections from the current thread + for thread_id, conn in list(self._connections.items()): + if thread_id == current_thread_id: + try: + if conn: + conn.close() + del self._connections[thread_id] + except sqlite3.Error: + pass # Ignore errors during shutdown def force_flush(self, timeout_millis=30000): """Force export of spans."""