This commit is contained in:
Ashwin Bharambe 2025-03-08 19:49:07 -08:00
parent 21769648a6
commit d89ef35151
5 changed files with 28 additions and 23 deletions

View file

@@ -38,7 +38,7 @@ class RequestProviderDataContext(ContextManager):
T = TypeVar("T") T = TypeVar("T")
async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
""" """
Wraps an async generator to preserve request headers context variables across iterations. Wraps an async generator to preserve request headers context variables across iterations.
@@ -46,19 +46,24 @@ async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None])
available during each iteration of the generator, even if the original available during each iteration of the generator, even if the original
context manager has exited. context manager has exited.
""" """
# Capture the current context value # Capture the current context value right now
context_value = _provider_data_var.get() context_value = _provider_data_var.get()
# Create a wrapper that restores context for each iteration async def wrapper():
async for item in gen: while True:
# Save the current token to restore later # Set context before each anext() call
token = _provider_data_var.set(context_value) token = _provider_data_var.set(context_value)
try: try:
item = await gen.__anext__()
yield item yield item
except StopAsyncIteration:
break
finally: finally:
# Restore the previous value # Restore the previous value
_provider_data_var.reset(token) _provider_data_var.reset(token)
return wrapper()
class NeedsRequestProviderData: class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any: def get_request_provider_data(self) -> Any:

View file

@@ -205,18 +205,14 @@ async def maybe_await(value):
async def sse_generator(event_gen): async def sse_generator(event_gen):
try: try:
event_gen = await event_gen async for item in await event_gen:
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(event_gen)
async for item in wrapped_gen:
yield create_sse_event(item) yield create_sse_event(item)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Generator cancelled") logger.info("Generator cancelled")
await event_gen.aclose() await event_gen.aclose()
except Exception as e: except Exception as e:
logger.exception(f"Error in sse_generator: {e}") logger.exception("Error in sse_generator")
logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
yield create_sse_event( yield create_sse_event(
{ {
"error": { "error": {
@@ -231,9 +227,11 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
# Use context manager for request provider data # Use context manager for request provider data
with request_provider_data_context(request.headers): with request_provider_data_context(request.headers):
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:
if is_streaming: if is_streaming:
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream") gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
return StreamingResponse(gen, media_type="text/event-stream")
else: else:
value = func(**kwargs) value = func(**kwargs)
return await maybe_await(value) return await maybe_await(value)

View file

@@ -70,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
pass pass
def _get_api_key(self) -> str: def _get_api_key(self) -> str:
if self.config.api_key is not None: config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
return self.config.api_key.get_secret_value() if config_api_key:
return config_api_key
else: else:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key: if provider_data is None or not provider_data.fireworks_api_key:

View file

@@ -93,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def _get_client(self) -> Together: def _get_client(self) -> Together:
together_api_key = None together_api_key = None
if self.config.api_key is not None: config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
together_api_key = self.config.api_key.get_secret_value() if config_api_key:
together_api_key = config_api_key
else: else:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key: if provider_data is None or not provider_data.together_api_key:

View file

@@ -42,7 +42,7 @@ def provider_data():
for key, value in keymap.items(): for key, value in keymap.items():
if os.environ.get(key): if os.environ.get(key):
provider_data[value] = os.environ[key] provider_data[value] = os.environ[key]
return provider_data if len(provider_data) > 0 else None return provider_data
@pytest.fixture(scope="session") @pytest.fixture(scope="session")