mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Concurrent requests should not trample (or reuse) each others' provider data. Provider data should be scoped to each request. ## Test Plan Set the uvicorn server to have a single worker process + thread by updating the config: ```python uvicorn_config = { ... "workers": 1, "loop": "asyncio", } ``` Then perform the following steps on `origin/main` (without this change). (1) Run the server using `llama stack run dev` without having `FIREWORKS_API_KEY` in the environment. (2) Run a test by specifying the FIREWORKS_API_KEY env var so it gets stored in the thread local ``` pytest -s -v tests/integration/inference/test_text_inference.py \ --stack-config http://localhost:8321 \ --text-model accounts/fireworks/models/llama-v3p1-8b-instruct \ -k test_text_chat_completion_with_tool_calling_and_streaming \ --env FIREWORKS_API_KEY=<...> ``` Ensure you don't have any other API keys in the environment (otherwise the bug will not reproduce due to other specifics in our testing code.) Verify this works. (3) Run the same command again without specifying FIREWORKS_API_KEY. See that the request actually succeeds when it *should have failed*. ---- Now do the same tests on this branch, verify step (3) results in failure. Finally, run the full `test_text_inference.py` test suite with this change, verify it succeeds.
113 lines
3.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import contextvars
|
|
import json
|
|
import logging
|
|
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
|
|
|
|
from .utils.dynamic import instantiate_class_type
|
|
|
|
log = logging.getLogger(__name__)


# Context variable holding the provider data for the current request.
# A ContextVar (rather than a module global or thread-local) scopes the value
# to the current asyncio task, so concurrent requests each see their own data.
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
|
|
|
|
|
|
class RequestProviderDataContext(ContextManager):
    """Context manager that scopes provider data to a `with` block.

    On entry the given provider data is installed into the request-scoped
    context variable; on exit the previous value is restored.
    """

    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
        self.provider_data = provider_data
        # Reset token returned by ContextVar.set(); populated in __enter__.
        self.token = None

    def __enter__(self):
        # Install the new value, remembering the token so the prior value
        # can be restored on exit.
        token = _provider_data_var.set(self.provider_data)
        self.token = token
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore whatever value was active before __enter__ ran.
        if self.token is not None:
            _provider_data_var.reset(self.token)
|
|
|
|
|
|
# Generic item type yielded by the async generators wrapped below.
T = TypeVar("T")
|
|
|
|
|
|
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Wraps an async generator to preserve request headers context variables across iterations.

    This ensures that context variables set during generator creation are
    available during each iteration of the generator, even if the original
    context manager has exited.

    The captured value is installed only around each ``__anext__()`` call and
    reset immediately afterwards, so it never leaks into the consumer's
    context — without the reset, one request's provider data could remain
    visible to whatever runs next in the consuming task.
    """
    # Capture the current context value right now, while the request's
    # context manager is still active.
    context_value = _provider_data_var.get()

    async def wrapper():
        while True:
            # Install the captured value just for this step of the generator.
            token = _provider_data_var.set(context_value)
            try:
                item = await gen.__anext__()
            except StopAsyncIteration:
                break
            finally:
                # Always restore the caller's context, including on
                # exhaustion and on errors raised by the generator.
                _provider_data_var.reset(token)
            yield item

    return wrapper()
|
|
|
|
|
|
class NeedsRequestProviderData:
    """Mixin for providers that consume validated per-request provider data."""

    def get_request_provider_data(self) -> Any:
        """Return the validated provider data for the current request.

        Returns None when no data was supplied for this request or when
        validation fails; raises ValueError if the provider spec declares
        no validator class.
        """
        spec = self.__provider_spec__
        assert spec, f"Provider spec not set on {self.__class__}"

        provider_type = spec.provider_type
        validator_class = spec.provider_data_validator
        if not validator_class:
            raise ValueError(f"Provider {provider_type} does not have a validator")

        # Raw dict parsed from the request headers (request-scoped).
        raw = _provider_data_var.get()
        if not raw:
            return None

        validator = instantiate_class_type(validator_class)
        try:
            return validator(**raw)
        except Exception as e:
            # Invalid data is treated the same as absent data.
            log.error(f"Error parsing provider data: {e}")
            return None
|
|
|
|
|
|
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
    """Parse provider data from request headers"""
    # Accept either header capitalization; first non-empty value wins.
    candidate_keys = (
        "X-LlamaStack-Provider-Data",
        "x-llamastack-provider-data",
    )
    raw = next((headers[key] for key in candidate_keys if headers.get(key)), None)

    if not raw:
        return None

    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        log.error("Provider data not encoded as a JSON object!")
        return None
|
|
|
|
|
|
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
    """Context manager that sets request provider data from headers for the duration of the context"""
    # Parse once up front; the context manager handles install/restore.
    return RequestProviderDataContext(parse_request_provider_data(headers))
|