feat(proxy_server.py): support returning available fields for pass_through_endpoints via `/config/field/list

This commit is contained in:
Krrish Dholakia 2024-08-14 19:07:10 -07:00
parent 10f27bb1b5
commit a020563149
3 changed files with 82 additions and 11 deletions

View file

@ -13,7 +13,15 @@ import traceback
import uuid
import warnings
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Optional
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
get_args,
get_origin,
get_type_hints,
)
import requests
@ -548,6 +556,20 @@ async def check_request_disconnection(request: Request, llm_api_call_task):
)
def _resolve_typed_dict_type(typ):
"""Resolve the actual TypedDict class from a potentially wrapped type."""
from typing_extensions import _TypedDictMeta # type: ignore
origin = get_origin(typ)
if origin is Union: # Check if it's a Union (like Optional)
for arg in get_args(typ):
if isinstance(arg, _TypedDictMeta):
return arg
elif isinstance(typ, type) and isinstance(typ, dict):
return typ
return None
def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj, user_api_key_cache
@ -9409,6 +9431,7 @@ async def get_config_list(
"global_max_parallel_requests": {"type": "Integer"},
"max_request_size_mb": {"type": "Integer"},
"max_response_size_mb": {"type": "Integer"},
"pass_through_endpoints": {"type": "TypedDictionary"},
}
return_val = []
@ -9416,6 +9439,33 @@ async def get_config_list(
for field_name, field_info in ConfigGeneralSettings.model_fields.items():
if field_name in allowed_args:
## HANDLE TYPED DICT
typed_dict_type = allowed_args[field_name]["type"]
if typed_dict_type == "TypedDictionary":
typed_dict_class: Optional[Any] = _resolve_typed_dict_type(
field_info.annotation
)
if typed_dict_class is None:
nested_fields = None
else:
# Get type hints from the TypedDict to create FieldDetail objects
nested_fields = [
FieldDetail(
field_name=sub_field,
field_type=type_hint.__name__,
field_description="", # Add custom logic if descriptions are available
field_default_value=general_settings.get(sub_field, None),
)
for sub_field, type_hint in get_type_hints(
typed_dict_class
).items()
]
else:
nested_fields = None
_stored_in_db = None
if field_name in db_general_settings_dict:
_stored_in_db = True
@ -9429,6 +9479,7 @@ async def get_config_list(
field_value=general_settings.get(field_name, None),
stored_in_db=_stored_in_db,
field_default_value=field_info.default,
nested_fields=nested_fields,
)
return_val.append(_response_obj)