diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 693e2f56c..911e9bbc7 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -29,6 +29,7 @@ from llama_stack.distribution.stack import (
     get_stack_run_config_from_template,
     replace_env_vars,
 )
+from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
 
 T = TypeVar("T")
 
@@ -234,21 +235,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         return await self._call_non_streaming(path, "POST", body)
 
     async def _call_non_streaming(self, path: str, method: str, body: dict = None):
-        func = self.endpoint_impls.get(path)
-        if not func:
-            raise ValueError(f"No endpoint found for {path}")
+        await start_trace(path, {"__location__": "library_client"})
+        try:
+            func = self.endpoint_impls.get(path)
+            if not func:
+                raise ValueError(f"No endpoint found for {path}")
 
-        body = self._convert_body(path, body)
-        return await func(**body)
+            body = self._convert_body(path, body)
+            return await func(**body)
+        finally:
+            end_trace()
 
     async def _call_streaming(self, path: str, method: str, body: dict = None):
-        func = self.endpoint_impls.get(path)
-        if not func:
-            raise ValueError(f"No endpoint found for {path}")
+        await start_trace(path, {"__location__": "library_client"})
+        try:
+            func = self.endpoint_impls.get(path)
+            if not func:
+                raise ValueError(f"No endpoint found for {path}")
 
-        body = self._convert_body(path, body)
-        async for chunk in await func(**body):
-            yield chunk
+            body = self._convert_body(path, body)
+            async for chunk in await func(**body):
+                yield chunk
+        finally:
+            end_trace()
 
     def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
         if not body:
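
For context, the change wraps every in-process call in a `start_trace`/`end_trace` pair so library-mode requests are traced the same way server-mode requests are. Below is a minimal sketch of that pattern in isolation; the `traced_call` helper and the `impls` mapping are illustrative names, not part of the PR.

```python
# Sketch only: illustrates the try/finally tracing pattern the diff applies.
# `traced_call` and `impls` are hypothetical; the real code lives on
# AsyncLlamaStackAsLibraryClient._call_non_streaming / _call_streaming.
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace


async def traced_call(impls: dict, path: str, body: dict):
    # Open a span named after the request path; the "__location__" attribute
    # marks the request as coming from the library client rather than a server.
    await start_trace(path, {"__location__": "library_client"})
    try:
        func = impls.get(path)
        if not func:
            raise ValueError(f"No endpoint found for {path}")
        return await func(**body)
    finally:
        # Runs even when the handler raises, so the trace is always closed.
        end_trace()
```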