Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
fix: Use re-entrancy and concurrency safe context managers for provider data (#1498)
Concurrent requests should not trample (or reuse) each other's provider data. Provider data should be scoped to each request.

## Test Plan

Set the uvicorn server to have a single worker process + thread by updating the config:

```python
uvicorn_config = {
    ...
    "workers": 1,
    "loop": "asyncio",
}
```

Then perform the following steps on `origin/main` (without this change):

1. Run the server using `llama stack run dev` without having `FIREWORKS_API_KEY` in the environment.

2. Run a test by specifying the `FIREWORKS_API_KEY` env var so it gets stored in the thread local:

   ```
   pytest -s -v tests/integration/inference/test_text_inference.py \
       --stack-config http://localhost:8321 \
       --text-model accounts/fireworks/models/llama-v3p1-8b-instruct \
       -k test_text_chat_completion_with_tool_calling_and_streaming \
       --env FIREWORKS_API_KEY=<...>
   ```

   Ensure you don't have any other API keys in the environment (otherwise the bug will not reproduce due to other specifics in our testing code). Verify this works.

3. Run the same command again without specifying `FIREWORKS_API_KEY`. See that the request actually succeeds when it *should have failed*.

----

Now do the same tests on this branch and verify that step 3 results in failure.

Finally, run the full `test_text_inference.py` test suite with this change and verify it succeeds.
This commit is contained in:
parent
6033e6893e
commit
205661bc78
6 changed files with 114 additions and 46 deletions
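Before reading the diff, it may help to see what a re-entrancy and concurrency safe context manager for provider data can look like. Below is a minimal sketch assuming a `contextvars`-based design: `request_provider_data_context` is the name imported in the diff, but this body and the `_provider_data` variable are illustrative, not the actual llama-stack implementation (which lives in `llama_stack.distribution.request_headers`).

```python
import contextvars
from contextlib import contextmanager

# Illustrative storage; a stand-in for whatever the real module uses.
_provider_data = contextvars.ContextVar("provider_data", default=None)

@contextmanager
def request_provider_data_context(headers):
    # ContextVar.set() returns a token, so nested (re-entrant) uses restore
    # the previous value on exit, and each asyncio task sees only its own
    # copy: concurrent requests cannot trample each other's provider data.
    token = _provider_data.set(dict(headers))
    try:
        yield
    finally:
        _provider_data.reset(token)
```

Unlike the old thread-local `set_request_provider_data`, nothing leaks across requests, as this small demo (using a hypothetical header dict) illustrates:

```python
import asyncio

async def fake_request(value: str) -> None:
    with request_provider_data_context({"api_key": value}):  # hypothetical headers
        await asyncio.sleep(0)  # yield to the event loop mid-"request"
        # Each task sees only its own value, never the other task's.
        assert _provider_data.get() == {"api_key": value}

async def main() -> None:
    await asyncio.gather(fake_request("a"), fake_request("b"))

asyncio.run(main())
```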
```diff
@@ -29,7 +29,10 @@ from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
```
```diff
@@ -202,16 +205,14 @@ async def maybe_await(value):
 
 async def sse_generator(event_gen):
     try:
-        event_gen = await event_gen
-        async for item in event_gen:
+        async for item in await event_gen:
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
         logger.info("Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
-        logger.exception(f"Error in sse_generator: {e}")
-        logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
+        logger.exception("Error in sse_generator")
         yield create_sse_event(
             {
                 "error": {
```
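This hunk iterates the awaited generator directly and drops the redundant traceback logging in favor of a single `logger.exception` call (which already records the traceback). The `create_sse_event` helper is not shown in the diff; as a rough sketch of standard Server-Sent Events framing (the body here is an assumption, not the repo's actual implementation):

```python
import json

def create_sse_event(data) -> str:
    # Standard SSE wire format: a "data:" field terminated by a blank line.
    # Assumed body for illustration; the real helper may serialize pydantic
    # models or non-JSON payloads differently.
    return f"data: {json.dumps(data)}\n\n"
```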
```diff
@@ -223,18 +224,20 @@ async def sse_generator(event_gen):
 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
-        set_request_provider_data(request.headers)
+        # Use context manager for request provider data
+        with request_provider_data_context(request.headers):
+            is_streaming = is_streaming_request(func.__name__, request, **kwargs)
 
-        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
-        try:
-            if is_streaming:
-                return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
-            else:
-                value = func(**kwargs)
-                return await maybe_await(value)
-        except Exception as e:
-            traceback.print_exception(e)
-            raise translate_exception(e) from e
+            try:
+                if is_streaming:
+                    gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
+                    return StreamingResponse(gen, media_type="text/event-stream")
+                else:
+                    value = func(**kwargs)
+                    return await maybe_await(value)
+            except Exception as e:
+                logger.exception("Error executing endpoint %s", method, route)
+                raise translate_exception(e) from e
 
     sig = inspect.signature(func)
 
```
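Why wrap the streaming path? `endpoint` returns the `StreamingResponse` immediately, so the `with request_provider_data_context(...)` block has already exited by the time the response body is iterated. `preserve_headers_context_async_generator` therefore has to capture the provider data up front and re-establish it on every iteration. A minimal sketch under that assumption; `_provider_data` is a stand-in for whatever context variable the `request_headers` module actually uses:

```python
import contextvars
from typing import AsyncGenerator, TypeVar

T = TypeVar("T")

# Stand-in context variable for illustration only.
_provider_data = contextvars.ContextVar("provider_data", default=None)

def preserve_headers_context_async_generator(
    gen: AsyncGenerator[T, None],
) -> AsyncGenerator[T, None]:
    # Capture the provider data visible right now, while the
    # request_provider_data_context block is still active.
    captured = _provider_data.get()

    async def wrapper() -> AsyncGenerator[T, None]:
        while True:
            # Re-establish the captured value for each step: the body is
            # consumed after `endpoint` has returned and the `with` block
            # has exited, so the value must be restored per iteration.
            token = _provider_data.set(captured)
            try:
                item = await gen.__anext__()
            except StopAsyncIteration:
                return
            finally:
                _provider_data.reset(token)
            yield item

    return wrapper()
```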