diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index fdc68c0a4..192667f2c 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -269,7 +269,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): set_request_provider_data( {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} ) - await start_trace(options.url, {"__location__": "library_client"}) + if stream: response = await self._call_streaming( cast_to=cast_to, @@ -281,7 +281,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cast_to=cast_to, options=options, ) - await end_trace() return response def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]: @@ -323,7 +322,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): matched_func, path_params = self._find_matching_endpoint(options.method, path) body |= path_params body = self._convert_body(path, options.method, body) - result = await matched_func(**body) + await start_trace(options.url, {"__location__": "library_client"}) + try: + result = await matched_func(**body) + finally: + await end_trace() json_content = json.dumps(convert_pydantic_to_json_value(result)) mock_response = httpx.Response( @@ -366,10 +369,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = self._convert_body(path, options.method, body) async def gen(): - async for chunk in await func(**body): - data = json.dumps(convert_pydantic_to_json_value(chunk)) - sse_event = f"data: {data}\n\n" - yield sse_event.encode("utf-8") + await start_trace(options.url, {"__location__": "library_client"}) + try: + async for chunk in await func(**body): + data = json.dumps(convert_pydantic_to_json_value(chunk)) + sse_event = f"data: {data}\n\n" + yield sse_event.encode("utf-8") + finally: + await end_trace() mock_response = httpx.Response( status_code=httpx.codes.OK, diff --git a/llama_stack/providers/utils/telemetry/dataset_mixin.py b/llama_stack/providers/utils/telemetry/dataset_mixin.py index bf5e79c3d..e488f2475 100644 --- a/llama_stack/providers/utils/telemetry/dataset_mixin.py +++ b/llama_stack/providers/utils/telemetry/dataset_mixin.py @@ -53,7 +53,7 @@ class TelemetryDatasetMixin: spans = [] for trace in traces: - spans_by_id = await self.get_span_tree( + spans_by_id = await self.query_span_tree( span_id=trace.root_span_id, attributes_to_return=attributes_to_return, max_depth=max_depth,