forked from phoenix-oss/llama-stack-mirror
add tracing to library client (#591)
This commit is contained in:
parent
ab7145a04f
commit
bc1fddf1df
2 changed files with 49 additions and 17 deletions
|
@ -22,6 +22,7 @@ from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.distribution.build import print_pip_install_help
|
from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
|
from llama_stack.distribution.datatypes import Api
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry
|
from llama_stack.distribution.resolver import ProviderRegistry
|
||||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
|
@ -29,6 +30,11 @@ from llama_stack.distribution.stack import (
|
||||||
get_stack_run_config_from_template,
|
get_stack_run_config_from_template,
|
||||||
replace_env_vars,
|
replace_env_vars,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
|
end_trace,
|
||||||
|
setup_logger,
|
||||||
|
start_trace,
|
||||||
|
)
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
@ -187,6 +193,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Set up telemetry logger similar to server.py
|
||||||
|
if Api.telemetry in self.impls:
|
||||||
|
setup_logger(self.impls[Api.telemetry])
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||||
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
||||||
|
@ -234,21 +244,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
return await self._call_non_streaming(path, "POST", body)
|
return await self._call_non_streaming(path, "POST", body)
|
||||||
|
|
||||||
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
|
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
|
||||||
func = self.endpoint_impls.get(path)
|
await start_trace(path, {"__location__": "library_client"})
|
||||||
if not func:
|
try:
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
func = self.endpoint_impls.get(path)
|
||||||
|
if not func:
|
||||||
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
body = self._convert_body(path, body)
|
||||||
return await func(**body)
|
return await func(**body)
|
||||||
|
finally:
|
||||||
|
end_trace()
|
||||||
|
|
||||||
async def _call_streaming(self, path: str, method: str, body: dict = None):
|
async def _call_streaming(self, path: str, method: str, body: dict = None):
|
||||||
func = self.endpoint_impls.get(path)
|
await start_trace(path, {"__location__": "library_client"})
|
||||||
if not func:
|
try:
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
func = self.endpoint_impls.get(path)
|
||||||
|
if not func:
|
||||||
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
body = self._convert_body(path, body)
|
||||||
async for chunk in await func(**body):
|
async for chunk in await func(**body):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
finally:
|
||||||
|
end_trace()
|
||||||
|
|
||||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
||||||
if not body:
|
if not body:
|
||||||
|
|
|
@ -20,6 +20,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
||||||
"""Initialize the SQLite span processor with a connection string."""
|
"""Initialize the SQLite span processor with a connection string."""
|
||||||
self.conn_string = conn_string
|
self.conn_string = conn_string
|
||||||
self.ttl_days = ttl_days
|
self.ttl_days = ttl_days
|
||||||
|
self._shutdown_event = threading.Event()
|
||||||
self.cleanup_task = None
|
self.cleanup_task = None
|
||||||
self._thread_local = threading.local()
|
self._thread_local = threading.local()
|
||||||
self._connections: Dict[int, sqlite3.Connection] = {}
|
self._connections: Dict[int, sqlite3.Connection] = {}
|
||||||
|
@ -144,9 +145,10 @@ class SQLiteSpanProcessor(SpanProcessor):
|
||||||
"""Run cleanup periodically."""
|
"""Run cleanup periodically."""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
while True:
|
while not self._shutdown_event.is_set():
|
||||||
time.sleep(3600) # Sleep for 1 hour
|
time.sleep(3600) # Sleep for 1 hour
|
||||||
self._cleanup_old_data()
|
if not self._shutdown_event.is_set():
|
||||||
|
self._cleanup_old_data()
|
||||||
|
|
||||||
def on_start(self, span: Span, parent_context=None):
|
def on_start(self, span: Span, parent_context=None):
|
||||||
"""Called when a span starts."""
|
"""Called when a span starts."""
|
||||||
|
@ -231,11 +233,23 @@ class SQLiteSpanProcessor(SpanProcessor):
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Cleanup any resources."""
|
"""Cleanup any resources."""
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
# Wait for cleanup thread to finish if it exists
|
||||||
|
if self.cleanup_task and self.cleanup_task.is_alive():
|
||||||
|
self.cleanup_task.join(timeout=5.0)
|
||||||
|
current_thread_id = threading.get_ident()
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for conn in self._connections.values():
|
# Close all connections from the current thread
|
||||||
if conn:
|
for thread_id, conn in list(self._connections.items()):
|
||||||
conn.close()
|
if thread_id == current_thread_id:
|
||||||
self._connections.clear()
|
try:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
del self._connections[thread_id]
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass # Ignore errors during shutdown
|
||||||
|
|
||||||
def force_flush(self, timeout_millis=30000):
|
def force_flush(self, timeout_millis=30000):
|
||||||
"""Force export of spans."""
|
"""Force export of spans."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue