This commit is contained in:
Ashwin Bharambe 2025-03-08 19:49:07 -08:00
parent 21769648a6
commit d89ef35151
5 changed files with 28 additions and 23 deletions

View file

@ -38,7 +38,7 @@ class RequestProviderDataContext(ContextManager):
T = TypeVar("T")
async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve request headers context variables across iterations.
@ -46,18 +46,23 @@ async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None])
available during each iteration of the generator, even if the original
context manager has exited.
"""
# Capture the current context value
# Capture the current context value right now
context_value = _provider_data_var.get()
# Create a wrapper that restores context for each iteration
async for item in gen:
# Save the current token to restore later
token = _provider_data_var.set(context_value)
try:
yield item
finally:
# Restore the previous value
_provider_data_var.reset(token)
async def wrapper():
while True:
# Set context before each anext() call
token = _provider_data_var.set(context_value)
try:
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
finally:
# Restore the previous value
_provider_data_var.reset(token)
return wrapper()
class NeedsRequestProviderData:

View file

@ -205,18 +205,14 @@ async def maybe_await(value):
async def sse_generator(event_gen):
try:
event_gen = await event_gen
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(event_gen)
async for item in wrapped_gen:
async for item in await event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("Generator cancelled")
await event_gen.aclose()
except Exception as e:
logger.exception(f"Error in sse_generator: {e}")
logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
logger.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -231,9 +227,11 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
# Use context manager for request provider data
with request_provider_data_context(request.headers):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)

View file

@ -70,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
pass
def _get_api_key(self) -> str:
if self.config.api_key is not None:
return self.config.api_key.get_secret_value()
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:

View file

@ -93,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def _get_client(self) -> Together:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key.get_secret_value()
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:

View file

@ -42,7 +42,7 @@ def provider_data():
for key, value in keymap.items():
if os.environ.get(key):
provider_data[value] = os.environ[key]
return provider_data if len(provider_data) > 0 else None
return provider_data
@pytest.fixture(scope="session")