diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 14f62e3a6..48fcc437b 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -67,6 +67,7 @@ def in_notebook():
 def stream_across_asyncio_run_boundary(
     async_gen_maker,
     pool_executor: ThreadPoolExecutor,
+    path: Optional[str] = None,
 ) -> Generator[T, None, None]:
     result_queue = queue.Queue()
     stop_event = threading.Event()
@@ -74,6 +75,7 @@ def stream_across_asyncio_run_boundary(
     async def consumer():
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
+        await start_trace(path, {"__location__": "library_client"})
         try:
             async for item in await gen:
                 result_queue.put(item)
@@ -85,6 +87,7 @@ def stream_across_asyncio_run_boundary(
         finally:
             result_queue.put(StopIteration)
             stop_event.set()
+            await end_trace()
 
     def run_async():
         # Run our own loop to avoid double async generator cleanup which is done
@@ -186,14 +189,34 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
 
         return asyncio.run(self.async_client.initialize())
 
+    def _get_path(
+        self,
+        cast_to: Any,
+        options: Any,
+        *,
+        stream=False,
+        stream_cls=None,
+    ):
+        return options.url
+
     def request(self, *args, **kwargs):
+        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
             return stream_across_asyncio_run_boundary(
                 lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
+                path=path,
             )
         else:
-            return asyncio.run(self.async_client.request(*args, **kwargs))
+
+            async def _traced_request():
+                await start_trace(path, {"__location__": "library_client"})
+                try:
+                    return await self.async_client.request(*args, **kwargs)
+                finally:
+                    await end_trace()
+
+            return asyncio.run(_traced_request())
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -206,7 +229,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
-        os.environ["TELEMETRY_SINKS"] = "sqlite"
+        current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
+        os.environ["TELEMETRY_SINKS"] = ",".join(
+            sink for sink in current_sinks if sink != "console"
+        )
 
         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
@@ -295,41 +321,37 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body = options.params or {}
         body |= options.json_data or {}
 
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")
 
-            body = self._convert_body(path, body)
-            result = await func(**body)
+        body = self._convert_body(path, body)
+        result = await func(**body)
 
-            json_content = json.dumps(convert_pydantic_to_json_value(result))
-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=json_content.encode("utf-8"),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
-            response = APIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=False,
-                stream_cls=None,
-            )
-            return response.parse()
-        finally:
-            await end_trace()
+        json_content = json.dumps(convert_pydantic_to_json_value(result))
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=json_content.encode("utf-8"),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )
+        response = APIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=False,
+            stream_cls=None,
+        )
+        return response.parse()
 
     async def _call_streaming(
         self,
@@ -341,51 +363,47 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")
 
-            body = self._convert_body(path, body)
+        body = self._convert_body(path, body)
 
-            async def gen():
-                async for chunk in await func(**body):
-                    data = json.dumps(convert_pydantic_to_json_value(chunk))
-                    sse_event = f"data: {data}\n\n"
-                    yield sse_event.encode("utf-8")
+        async def gen():
+            async for chunk in await func(**body):
+                data = json.dumps(convert_pydantic_to_json_value(chunk))
+                sse_event = f"data: {data}\n\n"
+                yield sse_event.encode("utf-8")
 
-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=gen(),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=gen(),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )
 
-            # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
-            # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
-            # so we need to convert it to AsyncStream
-            args = get_args(stream_cls)
-            stream_cls = AsyncStream[args[0]]
-            response = AsyncAPIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=True,
-                stream_cls=stream_cls,
-            )
-            return await response.parse()
-        finally:
-            await end_trace()
+        # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
+        # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
+        # so we need to convert it to AsyncStream
+        args = get_args(stream_cls)
+        stream_cls = AsyncStream[args[0]]
+        response = AsyncAPIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=True,
+            stream_cls=stream_cls,
+        )
+        return await response.parse()
 
     def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
         if not body: