mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00
fixes
This commit is contained in:
parent
21769648a6
commit
d89ef35151
5 changed files with 28 additions and 23 deletions
|
@ -38,7 +38,7 @@ class RequestProviderDataContext(ContextManager):
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
||||||
"""
|
"""
|
||||||
Wraps an async generator to preserve request headers context variables across iterations.
|
Wraps an async generator to preserve request headers context variables across iterations.
|
||||||
|
|
||||||
|
@ -46,18 +46,23 @@ async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None])
|
||||||
available during each iteration of the generator, even if the original
|
available during each iteration of the generator, even if the original
|
||||||
context manager has exited.
|
context manager has exited.
|
||||||
"""
|
"""
|
||||||
# Capture the current context value
|
# Capture the current context value right now
|
||||||
context_value = _provider_data_var.get()
|
context_value = _provider_data_var.get()
|
||||||
|
|
||||||
# Create a wrapper that restores context for each iteration
|
async def wrapper():
|
||||||
async for item in gen:
|
while True:
|
||||||
# Save the current token to restore later
|
# Set context before each anext() call
|
||||||
token = _provider_data_var.set(context_value)
|
token = _provider_data_var.set(context_value)
|
||||||
try:
|
try:
|
||||||
yield item
|
item = await gen.__anext__()
|
||||||
finally:
|
yield item
|
||||||
# Restore the previous value
|
except StopAsyncIteration:
|
||||||
_provider_data_var.reset(token)
|
break
|
||||||
|
finally:
|
||||||
|
# Restore the previous value
|
||||||
|
_provider_data_var.reset(token)
|
||||||
|
|
||||||
|
return wrapper()
|
||||||
|
|
||||||
|
|
||||||
class NeedsRequestProviderData:
|
class NeedsRequestProviderData:
|
||||||
|
|
|
@ -205,18 +205,14 @@ async def maybe_await(value):
|
||||||
|
|
||||||
async def sse_generator(event_gen):
|
async def sse_generator(event_gen):
|
||||||
try:
|
try:
|
||||||
event_gen = await event_gen
|
async for item in await event_gen:
|
||||||
# Wrap the generator to preserve context across iterations
|
|
||||||
wrapped_gen = preserve_headers_context_async_generator(event_gen)
|
|
||||||
async for item in wrapped_gen:
|
|
||||||
yield create_sse_event(item)
|
yield create_sse_event(item)
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Generator cancelled")
|
logger.info("Generator cancelled")
|
||||||
await event_gen.aclose()
|
await event_gen.aclose()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in sse_generator: {e}")
|
logger.exception("Error in sse_generator")
|
||||||
logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
|
|
||||||
yield create_sse_event(
|
yield create_sse_event(
|
||||||
{
|
{
|
||||||
"error": {
|
"error": {
|
||||||
|
@ -231,9 +227,11 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||||
# Use context manager for request provider data
|
# Use context manager for request provider data
|
||||||
with request_provider_data_context(request.headers):
|
with request_provider_data_context(request.headers):
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
||||||
|
return StreamingResponse(gen, media_type="text/event-stream")
|
||||||
else:
|
else:
|
||||||
value = func(**kwargs)
|
value = func(**kwargs)
|
||||||
return await maybe_await(value)
|
return await maybe_await(value)
|
||||||
|
|
|
@ -70,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
if self.config.api_key is not None:
|
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||||
return self.config.api_key.get_secret_value()
|
if config_api_key:
|
||||||
|
return config_api_key
|
||||||
else:
|
else:
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.fireworks_api_key:
|
if provider_data is None or not provider_data.fireworks_api_key:
|
||||||
|
|
|
@ -93,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
||||||
|
|
||||||
def _get_client(self) -> Together:
|
def _get_client(self) -> Together:
|
||||||
together_api_key = None
|
together_api_key = None
|
||||||
if self.config.api_key is not None:
|
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||||
together_api_key = self.config.api_key.get_secret_value()
|
if config_api_key:
|
||||||
|
together_api_key = config_api_key
|
||||||
else:
|
else:
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.together_api_key:
|
if provider_data is None or not provider_data.together_api_key:
|
||||||
|
|
|
@ -42,7 +42,7 @@ def provider_data():
|
||||||
for key, value in keymap.items():
|
for key, value in keymap.items():
|
||||||
if os.environ.get(key):
|
if os.environ.get(key):
|
||||||
provider_data[value] = os.environ[key]
|
provider_data[value] = os.environ[key]
|
||||||
return provider_data if len(provider_data) > 0 else None
|
return provider_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue