mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00
fixes
This commit is contained in:
parent
21769648a6
commit
d89ef35151
5 changed files with 28 additions and 23 deletions
|
@ -38,7 +38,7 @@ class RequestProviderDataContext(ContextManager):
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
||||
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve request headers context variables across iterations.
|
||||
|
||||
|
@ -46,18 +46,23 @@ async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None])
|
|||
available during each iteration of the generator, even if the original
|
||||
context manager has exited.
|
||||
"""
|
||||
# Capture the current context value
|
||||
# Capture the current context value right now
|
||||
context_value = _provider_data_var.get()
|
||||
|
||||
# Create a wrapper that restores context for each iteration
|
||||
async for item in gen:
|
||||
# Save the current token to restore later
|
||||
token = _provider_data_var.set(context_value)
|
||||
try:
|
||||
yield item
|
||||
finally:
|
||||
# Restore the previous value
|
||||
_provider_data_var.reset(token)
|
||||
async def wrapper():
|
||||
while True:
|
||||
# Set context before each anext() call
|
||||
token = _provider_data_var.set(context_value)
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
yield item
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
# Restore the previous value
|
||||
_provider_data_var.reset(token)
|
||||
|
||||
return wrapper()
|
||||
|
||||
|
||||
class NeedsRequestProviderData:
|
||||
|
|
|
@ -205,18 +205,14 @@ async def maybe_await(value):
|
|||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
# Wrap the generator to preserve context across iterations
|
||||
wrapped_gen = preserve_headers_context_async_generator(event_gen)
|
||||
async for item in wrapped_gen:
|
||||
async for item in await event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in sse_generator: {e}")
|
||||
logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -231,9 +227,11 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
# Use context manager for request provider data
|
||||
with request_provider_data_context(request.headers):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
|
|
|
@ -70,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key is not None:
|
||||
return self.config.api_key.get_secret_value()
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
|
|
|
@ -93,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key.get_secret_value()
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
|
|
|
@ -42,7 +42,7 @@ def provider_data():
|
|||
for key, value in keymap.items():
|
||||
if os.environ.get(key):
|
||||
provider_data[value] = os.environ[key]
|
||||
return provider_data if len(provider_data) > 0 else None
|
||||
return provider_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue