From 05f6b44da79ddd45b40e323ede34b86e700737bc Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 16 Jan 2025 10:36:13 -0800 Subject: [PATCH] Fix telemetry (#787) # What does this PR do? This PR fixes a couple of issues with telemetry: 1) The REST refactor changed the method from get_span_tree to query_span_tree; the outdated call was causing the server side to return empty spans. 2) The library client has introduced a new event loop, which required moving the start_trace and end_trace calls to where the request is actually executed. ## Test Plan LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py -k "test_builtin_tool_web_search" Also queried for spans from the agent run using the library client. --- llama_stack/distribution/library_client.py | 21 ++++++++++++------- .../utils/telemetry/dataset_mixin.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index fdc68c0a4..192667f2c 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -269,7 +269,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): set_request_provider_data( {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} ) - await start_trace(options.url, {"__location__": "library_client"}) + if stream: response = await self._call_streaming( cast_to=cast_to, @@ -281,7 +281,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cast_to=cast_to, options=options, ) - await end_trace() return response def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]: @@ -323,7 +322,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): matched_func, path_params = self._find_matching_endpoint(options.method, path) body |= path_params body = self._convert_body(path, options.method, body) - result = await matched_func(**body) + await
start_trace(options.url, {"__location__": "library_client"}) + try: + result = await matched_func(**body) + finally: + await end_trace() json_content = json.dumps(convert_pydantic_to_json_value(result)) mock_response = httpx.Response( @@ -366,10 +369,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = self._convert_body(path, options.method, body) async def gen(): - async for chunk in await func(**body): - data = json.dumps(convert_pydantic_to_json_value(chunk)) - sse_event = f"data: {data}\n\n" - yield sse_event.encode("utf-8") + await start_trace(options.url, {"__location__": "library_client"}) + try: + async for chunk in await func(**body): + data = json.dumps(convert_pydantic_to_json_value(chunk)) + sse_event = f"data: {data}\n\n" + yield sse_event.encode("utf-8") + finally: + await end_trace() mock_response = httpx.Response( status_code=httpx.codes.OK, diff --git a/llama_stack/providers/utils/telemetry/dataset_mixin.py b/llama_stack/providers/utils/telemetry/dataset_mixin.py index bf5e79c3d..e488f2475 100644 --- a/llama_stack/providers/utils/telemetry/dataset_mixin.py +++ b/llama_stack/providers/utils/telemetry/dataset_mixin.py @@ -53,7 +53,7 @@ class TelemetryDatasetMixin: spans = [] for trace in traces: - spans_by_id = await self.get_span_tree( + spans_by_id = await self.query_span_tree( span_id=trace.root_span_id, attributes_to_return=attributes_to_return, max_depth=max_depth,