add tracing to library client

This commit is contained in:
Dinesh Yeduguru 2024-12-09 15:30:10 -08:00
parent c699e884b5
commit 739f272e35

View file

@ -29,6 +29,7 @@ from llama_stack.distribution.stack import (
get_stack_run_config_from_template, get_stack_run_config_from_template,
replace_env_vars, replace_env_vars,
) )
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
T = TypeVar("T") T = TypeVar("T")
@ -234,21 +235,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
return await self._call_non_streaming(path, "POST", body) return await self._call_non_streaming(path, "POST", body)
async def _call_non_streaming(self, path: str, method: str, body: dict = None): async def _call_non_streaming(self, path: str, method: str, body: dict = None):
func = self.endpoint_impls.get(path) await start_trace(path, {"__location__": "library_client"})
if not func: try:
raise ValueError(f"No endpoint found for {path}") func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) body = self._convert_body(path, body)
return await func(**body) return await func(**body)
finally:
end_trace()
async def _call_streaming(self, path: str, method: str, body: dict = None): async def _call_streaming(self, path: str, method: str, body: dict = None):
func = self.endpoint_impls.get(path) await start_trace(path, {"__location__": "library_client"})
if not func: try:
raise ValueError(f"No endpoint found for {path}") func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) body = self._convert_body(path, body)
async for chunk in await func(**body): async for chunk in await func(**body):
yield chunk yield chunk
finally:
end_trace()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body: if not body: