diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 54b974258..c01823b99 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -117,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_template_name, custom_provider_registry + config_path_or_template_name, custom_provider_registry, provider_data ) self.pool_executor = ThreadPoolExecutor(max_workers=4) self.skip_logger_removal = skip_logger_removal @@ -143,18 +143,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): root_logger.removeHandler(handler) print(f"Removed handler {handler.__class__.__name__} from root logger") - def _get_path( - self, - cast_to: Any, - options: Any, - *, - stream=False, - stream_cls=None, - ): - return options.url - def request(self, *args, **kwargs): - path = self._get_path(*args, **kwargs) if kwargs.get("stream"): # NOTE: We are using AsyncLlamaStackClient under the hood # A new event loop is needed to convert the AsyncStream @@ -180,18 +169,10 @@ class LlamaStackAsLibraryClient(LlamaStackClient): return sync_generator() else: - async def _traced_request(): - if self.provider_data: - set_request_provider_data( - {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} - ) - await start_trace(path, {"__location__": "library_client"}) - try: - return await self.async_client.request(*args, **kwargs) - finally: - await end_trace() + async def _request(): + return await self.async_client.request(*args, **kwargs) - return asyncio.run(_traced_request()) + return asyncio.run(_request()) class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): @@ -199,9 +180,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self, config_path_or_template_name: str, custom_provider_registry: Optional[ProviderRegistry] = None, + provider_data: Optional[dict[str, Any]] = None, ): super().__init__() - # when using the library client, we should not log to console since many # of our logs are intended for server-side usage current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") @@ -222,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.config_path_or_template_name = config_path_or_template_name self.config = config self.custom_provider_registry = custom_provider_registry + self.provider_data = provider_data async def initialize(self): try: @@ -278,17 +260,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not self.endpoint_impls: raise ValueError("Client not initialized") + if self.provider_data: + set_request_provider_data( + {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} + ) + await start_trace(options.url, {"__location__": "library_client"}) if stream: - return await self._call_streaming( + response = await self._call_streaming( cast_to=cast_to, options=options, stream_cls=stream_cls, ) else: - return await self._call_non_streaming( + response = await self._call_non_streaming( cast_to=cast_to, options=options, ) + await end_trace() + return response async def _call_non_streaming( self,