forked from phoenix-oss/llama-stack-mirror
Adding safety adapter for Together
This commit is contained in:
parent
0d2eb3bd25
commit
b85d675c6f
8 changed files with 188 additions and 23 deletions
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
|
@ -17,8 +17,8 @@ def get_request_provider_data() -> Any:
|
|||
return getattr(_THREAD_LOCAL, "provider_data", None)
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]):
|
||||
if not validator_class:
|
||||
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
|
||||
if not validator_classes:
|
||||
return
|
||||
|
||||
keys = [
|
||||
|
@ -39,11 +39,12 @@ def set_request_provider_data(headers: Dict[str, str], validator_class: Optional
|
|||
print("Provider data not encoded as a JSON object!", val)
|
||||
return
|
||||
|
||||
validator = instantiate_class_type(validator_class)
|
||||
try:
|
||||
provider_data = validator(**val)
|
||||
except Exception as e:
|
||||
print("Error parsing provider data", e)
|
||||
return
|
||||
|
||||
_THREAD_LOCAL.provider_data = provider_data
|
||||
for validator_class in validator_classes:
|
||||
validator = instantiate_class_type(validator_class)
|
||||
try:
|
||||
provider_data = validator(**val)
|
||||
if provider_data:
|
||||
_THREAD_LOCAL.provider_data = provider_data
|
||||
return
|
||||
except Exception as e:
|
||||
print("Error parsing provider data", e)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue