Adding safety adapter for Together

This commit is contained in:
Yogish Baliga 2024-09-20 09:35:01 -07:00 committed by Ashwin Bharambe
parent 0d2eb3bd25
commit b85d675c6f
8 changed files with 188 additions and 23 deletions

View file

@ -6,7 +6,7 @@
import json
import threading
from typing import Any, Dict, Optional
from typing import Any, Dict, List
from .utils.dynamic import instantiate_class_type
@ -17,8 +17,8 @@ def get_request_provider_data() -> Any:
return getattr(_THREAD_LOCAL, "provider_data", None)
def set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]):
if not validator_class:
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
if not validator_classes:
return
keys = [
@ -39,11 +39,12 @@ def set_request_provider_data(headers: Dict[str, str], validator_class: Optional
print("Provider data not encoded as a JSON object!", val)
return
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
except Exception as e:
print("Error parsing provider data", e)
return
_THREAD_LOCAL.provider_data = provider_data
for validator_class in validator_classes:
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
if provider_data:
_THREAD_LOCAL.provider_data = provider_data
return
except Exception as e:
print("Error parsing provider data", e)