mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 04:20:03 +00:00
address feedback
This commit is contained in:
parent
a900740e30
commit
714c09cd53
4 changed files with 47 additions and 30 deletions
|
|
@ -14,7 +14,7 @@ from .utils.dynamic import instantiate_class_type
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for request provider data
|
||||
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
|
|
@ -26,13 +26,13 @@ class RequestProviderDataContext(ContextManager):
|
|||
|
||||
def __enter__(self):
|
||||
# Save the current value and set the new one
|
||||
self.token = _provider_data_var.set(self.provider_data)
|
||||
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Restore the previous value
|
||||
if self.token is not None:
|
||||
_provider_data_var.reset(self.token)
|
||||
PROVIDER_DATA_VAR.reset(self.token)
|
||||
|
||||
|
||||
class NeedsRequestProviderData:
|
||||
|
|
@ -45,7 +45,7 @@ class NeedsRequestProviderData:
|
|||
if not validator_class:
|
||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||
|
||||
val = _provider_data_var.get()
|
||||
val = PROVIDER_DATA_VAR.get()
|
||||
if not val:
|
||||
return None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue