mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fix: Use re-entrancy and concurrency safe context managers for provider data
This commit is contained in:
parent
6033e6893e
commit
4a894b925d
3 changed files with 67 additions and 37 deletions
|
@ -32,7 +32,7 @@ from termcolor import cprint
|
||||||
from llama_stack.distribution.build import print_pip_install_help
|
from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import request_provider_data_context
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry
|
from llama_stack.distribution.resolver import ProviderRegistry
|
||||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
|
@ -262,21 +262,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
if not self.endpoint_impls:
|
if not self.endpoint_impls:
|
||||||
raise ValueError("Client not initialized")
|
raise ValueError("Client not initialized")
|
||||||
|
|
||||||
|
# Create headers with provider data if available
|
||||||
|
headers = {}
|
||||||
if self.provider_data:
|
if self.provider_data:
|
||||||
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
|
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
|
||||||
|
|
||||||
if stream:
|
# Use context manager for provider data
|
||||||
response = await self._call_streaming(
|
with request_provider_data_context(headers):
|
||||||
cast_to=cast_to,
|
if stream:
|
||||||
options=options,
|
response = await self._call_streaming(
|
||||||
stream_cls=stream_cls,
|
cast_to=cast_to,
|
||||||
)
|
options=options,
|
||||||
else:
|
stream_cls=stream_cls,
|
||||||
response = await self._call_non_streaming(
|
)
|
||||||
cast_to=cast_to,
|
else:
|
||||||
options=options,
|
response = await self._call_non_streaming(
|
||||||
)
|
cast_to=cast_to,
|
||||||
return response
|
options=options,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||||
"""Find the matching endpoint implementation for a given method and path.
|
"""Find the matching endpoint implementation for a given method and path.
|
||||||
|
|
|
@ -4,16 +4,35 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import contextvars
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
from typing import Any, ContextManager, Dict, Optional
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
_THREAD_LOCAL = threading.local()
|
# Context variable for request provider data
|
||||||
|
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestProviderDataContext(ContextManager):
|
||||||
|
"""Context manager for request provider data"""
|
||||||
|
|
||||||
|
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
|
||||||
|
self.provider_data = provider_data
|
||||||
|
self.token = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
# Save the current value and set the new one
|
||||||
|
self.token = _provider_data_var.set(self.provider_data)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
# Restore the previous value
|
||||||
|
if self.token is not None:
|
||||||
|
_provider_data_var.reset(self.token)
|
||||||
|
|
||||||
|
|
||||||
class NeedsRequestProviderData:
|
class NeedsRequestProviderData:
|
||||||
|
@ -26,7 +45,7 @@ class NeedsRequestProviderData:
|
||||||
if not validator_class:
|
if not validator_class:
|
||||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||||
|
|
||||||
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
|
val = _provider_data_var.get()
|
||||||
if not val:
|
if not val:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -36,25 +55,32 @@ class NeedsRequestProviderData:
|
||||||
return provider_data
|
return provider_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error parsing provider data: {e}")
|
log.error(f"Error parsing provider data: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def set_request_provider_data(headers: Dict[str, str]):
|
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Parse provider data from request headers"""
|
||||||
keys = [
|
keys = [
|
||||||
"X-LlamaStack-Provider-Data",
|
"X-LlamaStack-Provider-Data",
|
||||||
"x-llamastack-provider-data",
|
"x-llamastack-provider-data",
|
||||||
]
|
]
|
||||||
|
val = None
|
||||||
for key in keys:
|
for key in keys:
|
||||||
val = headers.get(key, None)
|
val = headers.get(key, None)
|
||||||
if val:
|
if val:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not val:
|
if not val:
|
||||||
return
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
val = json.loads(val)
|
return json.loads(val)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
log.error("Provider data not encoded as a JSON object!", val)
|
log.error("Provider data not encoded as a JSON object!")
|
||||||
return
|
return None
|
||||||
|
|
||||||
_THREAD_LOCAL.provider_data_header_value = val
|
|
||||||
|
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
|
||||||
|
"""Context manager that sets request provider data from headers for the duration of the context"""
|
||||||
|
provider_data = parse_request_provider_data(headers)
|
||||||
|
return RequestProviderDataContext(provider_data)
|
||||||
|
|
|
@ -29,7 +29,7 @@ from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import request_provider_data_context
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
construct_stack,
|
construct_stack,
|
||||||
|
@ -223,18 +223,18 @@ async def sse_generator(event_gen):
|
||||||
|
|
||||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||||
async def endpoint(request: Request, **kwargs):
|
async def endpoint(request: Request, **kwargs):
|
||||||
set_request_provider_data(request.headers)
|
# Use context manager for request provider data
|
||||||
|
with request_provider_data_context(request.headers):
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
try:
|
try:
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||||
else:
|
else:
|
||||||
value = func(**kwargs)
|
value = func(**kwargs)
|
||||||
return await maybe_await(value)
|
return await maybe_await(value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exception(e)
|
traceback.print_exception(e)
|
||||||
raise translate_exception(e) from e
|
raise translate_exception(e) from e
|
||||||
|
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue