From d89ef35151b4863acf8f10bbf2e16fb2716c95e1 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 8 Mar 2025 19:49:07 -0800 Subject: [PATCH] fixes --- llama_stack/distribution/request_headers.py | 27 +++++++++++-------- llama_stack/distribution/server/server.py | 12 ++++----- .../remote/inference/fireworks/fireworks.py | 5 ++-- .../remote/inference/together/together.py | 5 ++-- tests/integration/fixtures/common.py | 2 +- 5 files changed, 28 insertions(+), 23 deletions(-) diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index f617ce945..552ac15c0 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -38,7 +38,7 @@ class RequestProviderDataContext(ContextManager): T = TypeVar("T") -async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: +def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: """ Wraps an async generator to preserve request headers context variables across iterations. @@ -46,18 +46,23 @@ async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) available during each iteration of the generator, even if the original context manager has exited. """ - # Capture the current context value + # Capture the current context value right now context_value = _provider_data_var.get() - # Create a wrapper that restores context for each iteration - async for item in gen: - # Save the current token to restore later - token = _provider_data_var.set(context_value) - try: - yield item - finally: - # Restore the previous value - _provider_data_var.reset(token) + async def wrapper(): + while True: + # Set context before each anext() call + token = _provider_data_var.set(context_value) + try: + item = await gen.__anext__() + yield item + except StopAsyncIteration: + break + finally: + # Restore the previous value + _provider_data_var.reset(token) + + return wrapper() class NeedsRequestProviderData: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 842798614..5ca759f1c 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -205,18 +205,14 @@ async def maybe_await(value): async def sse_generator(event_gen): try: - event_gen = await event_gen - # Wrap the generator to preserve context across iterations - wrapped_gen = preserve_headers_context_async_generator(event_gen) - async for item in wrapped_gen: + async for item in await event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: logger.info("Generator cancelled") await event_gen.aclose() except Exception as e: - logger.exception(f"Error in sse_generator: {e}") - logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}") + logger.exception("Error in sse_generator") yield create_sse_event( { "error": { @@ -231,9 +227,11 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): # Use context manager for request provider data with request_provider_data_context(request.headers): is_streaming = is_streaming_request(func.__name__, request, **kwargs) + try: if is_streaming: - return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream") + gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs))) + return StreamingResponse(gen, media_type="text/event-stream") else: value = func(**kwargs) return await maybe_await(value) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index ec68fb556..4acbe43f8 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -70,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv pass def _get_api_key(self) -> str: - if self.config.api_key is not None: - return self.config.api_key.get_secret_value() + config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None + if config_api_key: + return config_api_key else: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.fireworks_api_key: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 2046d4aae..dfc9ae6d3 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -93,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi def _get_client(self) -> Together: together_api_key = None - if self.config.api_key is not None: - together_api_key = self.config.api_key.get_secret_value() + config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None + if config_api_key: + together_api_key = config_api_key else: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.together_api_key: diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 6a75b3adf..e410039e7 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -42,7 +42,7 @@ def provider_data(): for key, value in keymap.items(): if os.environ.get(key): provider_data[value] = os.environ[key] - return provider_data if len(provider_data) > 0 else None + return provider_data @pytest.fixture(scope="session")