feat(pass_through_endpoints.py): initial working CRUD endpoints for /pass_through_endoints

This commit is contained in:
Krrish Dholakia 2024-08-15 21:23:26 -07:00
parent 28faafadb1
commit 589da45c24
3 changed files with 264 additions and 38 deletions

View file

@ -573,6 +573,23 @@ def _resolve_typed_dict_type(typ):
return None
def _resolve_pydantic_type(typ) -> List:
"""Resolve the actual TypedDict class from a potentially wrapped type."""
origin = get_origin(typ)
typs = []
if origin is Union: # Check if it's a Union (like Optional)
for arg in get_args(typ):
if (
arg is not None
and not isinstance(arg, type(None))
and "NoneType" not in str(arg)
):
typs.append(arg)
elif isinstance(typ, type) and isinstance(typ, BaseModel):
return [typ]
return typs
def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj, user_api_key_cache
@ -2204,6 +2221,15 @@ class ProxyConfig:
alerting_args=general_settings["alerting_args"],
)
## PASS-THROUGH ENDPOINTS ##
if "pass_through_endpoints" in _general_settings:
general_settings["pass_through_endpoints"] = _general_settings[
"pass_through_endpoints"
]
await initialize_pass_through_endpoints(
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
async def add_deployment(
self,
prisma_client: PrismaClient,
@ -9434,7 +9460,7 @@ async def get_config_list(
"global_max_parallel_requests": {"type": "Integer"},
"max_request_size_mb": {"type": "Integer"},
"max_response_size_mb": {"type": "Integer"},
"pass_through_endpoints": {"type": "TypedDictionary"},
"pass_through_endpoints": {"type": "PydanticModel"},
}
return_val = []
@ -9446,45 +9472,76 @@ async def get_config_list(
typed_dict_type = allowed_args[field_name]["type"]
if typed_dict_type == "TypedDictionary":
typed_dict_class: Optional[Any] = _resolve_typed_dict_type(
if typed_dict_type == "PydanticModel":
pydantic_class_list: Optional[Any] = _resolve_pydantic_type(
field_info.annotation
)
if pydantic_class_list is None:
continue
if typed_dict_class is None:
nested_fields = None
else:
for pydantic_class in pydantic_class_list:
# Get type hints from the TypedDict to create FieldDetail objects
nested_fields = [
FieldDetail(
field_name=sub_field,
field_type=type_hint.__name__,
field_type=sub_field_type.__name__,
field_description="", # Add custom logic if descriptions are available
field_default_value=general_settings.get(sub_field, None),
stored_in_db=None,
)
for sub_field, type_hint in get_type_hints(
typed_dict_class
).items()
for sub_field, sub_field_type in pydantic_class.__annotations__.items()
]
idx = 0
for (
sub_field,
sub_field_info,
) in pydantic_class.model_fields.items():
if (
hasattr(sub_field_info, "description")
and sub_field_info.description is not None
):
nested_fields[idx].field_description = (
sub_field_info.description
)
idx += 1
_stored_in_db = None
if field_name in db_general_settings_dict:
_stored_in_db = True
elif field_name in general_settings:
_stored_in_db = False
_response_obj = ConfigList(
field_name=field_name,
field_type=allowed_args[field_name]["type"],
field_description=field_info.description or "",
field_value=general_settings.get(field_name, None),
stored_in_db=_stored_in_db,
field_default_value=field_info.default,
nested_fields=nested_fields,
)
return_val.append(_response_obj)
else:
nested_fields = None
_stored_in_db = None
if field_name in db_general_settings_dict:
_stored_in_db = True
elif field_name in general_settings:
_stored_in_db = False
_stored_in_db = None
if field_name in db_general_settings_dict:
_stored_in_db = True
elif field_name in general_settings:
_stored_in_db = False
_response_obj = ConfigList(
field_name=field_name,
field_type=allowed_args[field_name]["type"],
field_description=field_info.description or "",
field_value=general_settings.get(field_name, None),
stored_in_db=_stored_in_db,
field_default_value=field_info.default,
nested_fields=nested_fields,
)
return_val.append(_response_obj)
_response_obj = ConfigList(
field_name=field_name,
field_type=allowed_args[field_name]["type"],
field_description=field_info.description or "",
field_value=general_settings.get(field_name, None),
stored_in_db=_stored_in_db,
field_default_value=field_info.default,
nested_fields=nested_fields,
)
return_val.append(_response_obj)
return return_val