From 589da45c24548fcd9f1c280f55fa17ca640a1eb7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 15 Aug 2024 21:23:26 -0700 Subject: [PATCH] feat(pass_through_endpoints.py): initial working CRUD endpoints for /pass_through_endoints --- litellm/proxy/_types.py | 25 ++- .../pass_through_endpoints.py | 170 +++++++++++++++++- litellm/proxy/proxy_server.py | 107 ++++++++--- 3 files changed, 264 insertions(+), 38 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index cb60fa5f04..94428c7feb 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1082,10 +1082,18 @@ class DynamoDBArgs(LiteLLMBase): assume_role_aws_session_name: Optional[str] = None -class PassThroughEndpointTypedDict(TypedDict): - path: str - target: str - headers: dict +class PassThroughGenericEndpoint(LiteLLMBase): + path: str = Field(description="The route to be added to the LiteLLM Proxy Server.") + target: str = Field( + description="The URL to which requests for this path should be forwarded." + ) + headers: dict = Field( + description="Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint" + ) + + +class PassThroughEndpointResponse(LiteLLMBase): + endpoints: List[PassThroughGenericEndpoint] class ConfigFieldUpdate(LiteLLMBase): @@ -1104,6 +1112,7 @@ class FieldDetail(BaseModel): field_type: str field_description: str field_default_value: Any = None + stored_in_db: Optional[bool] class ConfigList(LiteLLMBase): @@ -1219,7 +1228,7 @@ class ConfigGeneralSettings(LiteLLMBase): default=False, description="Public model hub for users to see what models they have access to, supported openai params, etc.", ) - pass_through_endpoints: Optional[PassThroughEndpointTypedDict] = Field( + pass_through_endpoints: Optional[List[PassThroughGenericEndpoint]] = Field( default=None, description="Set-up pass-through endpoints for provider-specific endpoints. Docs - https://docs.litellm.ai/docs/proxy/pass_through", ) @@ -1781,3 +1790,9 @@ class VirtualKeyEvent(LiteLLMBase): created_by_user_role: str created_by_key_alias: Optional[str] request_kwargs: dict + + +class CreatePassThroughEndpoint(LiteLLMBase): + path: str + target: str + headers: dict diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 1d9f691177..61893a3dc4 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -20,7 +20,14 @@ from fastapi.responses import StreamingResponse import litellm from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger -from litellm.proxy._types import ProxyException, UserAPIKeyAuth +from litellm.proxy._types import ( + ConfigFieldInfo, + ConfigFieldUpdate, + PassThroughEndpointResponse, + PassThroughGenericEndpoint, + ProxyException, + UserAPIKeyAuth, +) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth router = APIRouter() @@ -481,16 +488,64 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): @router.get( - "/config/pass_through_endpoint/{endpoint_id}", + "/config/pass_through_endpoint", tags=["Internal User management"], dependencies=[Depends(user_api_key_auth)], + response_model=PassThroughEndpointResponse, ) -async def get_pass_through_endpoints(request: Request, endpoint_id: str): +async def get_pass_through_endpoints( + endpoint_id: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ GET configured pass through endpoint. If no endpoint_id given, return all configured endpoints. """ + from litellm.proxy.proxy_server import get_config_general_settings + + ## Get existing pass-through endpoint field value + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + return PassThroughEndpointResponse(endpoints=[]) + + pass_through_endpoint_data: Optional[List] = response.field_value + if pass_through_endpoint_data is None: + return PassThroughEndpointResponse(endpoints=[]) + + returned_endpoints = [] + if endpoint_id is None: + for endpoint in pass_through_endpoint_data: + if isinstance(endpoint, dict): + returned_endpoints.append(PassThroughGenericEndpoint(**endpoint)) + elif isinstance(endpoint, PassThroughGenericEndpoint): + returned_endpoints.append(endpoint) + elif endpoint_id is not None: + for endpoint in pass_through_endpoint_data: + _endpoint: Optional[PassThroughGenericEndpoint] = None + if isinstance(endpoint, dict): + _endpoint = PassThroughGenericEndpoint(**endpoint) + elif isinstance(endpoint, PassThroughGenericEndpoint): + _endpoint = endpoint + + if _endpoint is not None and _endpoint.path == endpoint_id: + returned_endpoints.append(_endpoint) + + return PassThroughEndpointResponse(endpoints=returned_endpoints) + + +@router.post( + "/config/pass_through_endpoint/{endpoint_id}", + tags=["Internal User management"], + dependencies=[Depends(user_api_key_auth)], +) +async def update_pass_through_endpoints(request: Request, endpoint_id: str): + """ + Update a pass-through endpoint + """ pass @@ -499,20 +554,119 @@ async def get_pass_through_endpoints(request: Request, endpoint_id: str): tags=["Internal User management"], dependencies=[Depends(user_api_key_auth)], ) -async def create_pass_through_endpoints(request: Request): +async def create_pass_through_endpoints( + data: PassThroughGenericEndpoint, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ Create new pass-through endpoint """ - pass + from litellm.proxy.proxy_server import ( + get_config_general_settings, + update_config_general_settings, + ) + + ## Get existing pass-through endpoint field value + + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + response = ConfigFieldInfo( + field_name="pass_through_endpoints", field_value=None + ) + + ## Update field with new endpoint + data_dict = data.model_dump() + if response.field_value is None: + response.field_value = [data_dict] + elif isinstance(response.field_value, List): + response.field_value.append(data_dict) + + ## Update db + updated_data = ConfigFieldUpdate( + field_name="pass_through_endpoints", + field_value=response.field_value, + config_type="general_settings", + ) + await update_config_general_settings( + data=updated_data, user_api_key_dict=user_api_key_dict + ) @router.delete( "/config/pass_through_endpoint", tags=["Internal User management"], dependencies=[Depends(user_api_key_auth)], + response_model=PassThroughEndpointResponse, ) -async def delete_pass_through_endpoints(request: Request): +async def delete_pass_through_endpoints( + endpoint_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ - Create new pass-through endpoint + Delete a pass-through endpoint + + Returns - the deleted endpoint """ - pass + from litellm.proxy.proxy_server import ( + get_config_general_settings, + update_config_general_settings, + ) + + ## Get existing pass-through endpoint field value + + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + response = ConfigFieldInfo( + field_name="pass_through_endpoints", field_value=None + ) + + ## Update field by removing endpoint + pass_through_endpoint_data: Optional[List] = response.field_value + response_obj: Optional[PassThroughGenericEndpoint] = None + if response.field_value is None or pass_through_endpoint_data is None: + raise HTTPException( + status_code=400, + detail={"error": "There are no pass-through endpoints setup."}, + ) + elif isinstance(response.field_value, List): + invalid_idx: Optional[int] = None + for idx, endpoint in enumerate(pass_through_endpoint_data): + _endpoint: Optional[PassThroughGenericEndpoint] = None + if isinstance(endpoint, dict): + _endpoint = PassThroughGenericEndpoint(**endpoint) + elif isinstance(endpoint, PassThroughGenericEndpoint): + _endpoint = endpoint + + if _endpoint is not None and _endpoint.path == endpoint_id: + invalid_idx = idx + response_obj = _endpoint + + if invalid_idx is not None: + pass_through_endpoint_data.pop(invalid_idx) + + ## Update db + updated_data = ConfigFieldUpdate( + field_name="pass_through_endpoints", + field_value=pass_through_endpoint_data, + config_type="general_settings", + ) + await update_config_general_settings( + data=updated_data, user_api_key_dict=user_api_key_dict + ) + + if response_obj is None: + raise HTTPException( + status_code=400, + detail={ + "error": "Endpoint={} was not found in pass-through endpoint list.".format( + endpoint_id + ) + }, + ) + return PassThroughEndpointResponse(endpoints=[response_obj]) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6483a61b53..a331e150ef 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -573,6 +573,23 @@ def _resolve_typed_dict_type(typ): return None +def _resolve_pydantic_type(typ) -> List: + """Resolve the actual TypedDict class from a potentially wrapped type.""" + origin = get_origin(typ) + typs = [] + if origin is Union: # Check if it's a Union (like Optional) + for arg in get_args(typ): + if ( + arg is not None + and not isinstance(arg, type(None)) + and "NoneType" not in str(arg) + ): + typs.append(arg) + elif isinstance(typ, type) and isinstance(typ, BaseModel): + return [typ] + return typs + + def prisma_setup(database_url: Optional[str]): global prisma_client, proxy_logging_obj, user_api_key_cache @@ -2204,6 +2221,15 @@ class ProxyConfig: alerting_args=general_settings["alerting_args"], ) + ## PASS-THROUGH ENDPOINTS ## + if "pass_through_endpoints" in _general_settings: + general_settings["pass_through_endpoints"] = _general_settings[ + "pass_through_endpoints" + ] + await initialize_pass_through_endpoints( + pass_through_endpoints=general_settings["pass_through_endpoints"] + ) + async def add_deployment( self, prisma_client: PrismaClient, @@ -9434,7 +9460,7 @@ async def get_config_list( "global_max_parallel_requests": {"type": "Integer"}, "max_request_size_mb": {"type": "Integer"}, "max_response_size_mb": {"type": "Integer"}, - "pass_through_endpoints": {"type": "TypedDictionary"}, + "pass_through_endpoints": {"type": "PydanticModel"}, } return_val = [] @@ -9446,45 +9472,76 @@ async def get_config_list( typed_dict_type = allowed_args[field_name]["type"] - if typed_dict_type == "TypedDictionary": - typed_dict_class: Optional[Any] = _resolve_typed_dict_type( + if typed_dict_type == "PydanticModel": + pydantic_class_list: Optional[Any] = _resolve_pydantic_type( field_info.annotation ) + if pydantic_class_list is None: + continue - if typed_dict_class is None: - nested_fields = None - else: + for pydantic_class in pydantic_class_list: # Get type hints from the TypedDict to create FieldDetail objects nested_fields = [ FieldDetail( field_name=sub_field, - field_type=type_hint.__name__, + field_type=sub_field_type.__name__, field_description="", # Add custom logic if descriptions are available field_default_value=general_settings.get(sub_field, None), + stored_in_db=None, ) - for sub_field, type_hint in get_type_hints( - typed_dict_class - ).items() + for sub_field, sub_field_type in pydantic_class.__annotations__.items() ] + + idx = 0 + for ( + sub_field, + sub_field_info, + ) in pydantic_class.model_fields.items(): + if ( + hasattr(sub_field_info, "description") + and sub_field_info.description is not None + ): + nested_fields[idx].field_description = ( + sub_field_info.description + ) + idx += 1 + + _stored_in_db = None + if field_name in db_general_settings_dict: + _stored_in_db = True + elif field_name in general_settings: + _stored_in_db = False + + _response_obj = ConfigList( + field_name=field_name, + field_type=allowed_args[field_name]["type"], + field_description=field_info.description or "", + field_value=general_settings.get(field_name, None), + stored_in_db=_stored_in_db, + field_default_value=field_info.default, + nested_fields=nested_fields, + ) + return_val.append(_response_obj) + else: nested_fields = None - _stored_in_db = None - if field_name in db_general_settings_dict: - _stored_in_db = True - elif field_name in general_settings: - _stored_in_db = False + _stored_in_db = None + if field_name in db_general_settings_dict: + _stored_in_db = True + elif field_name in general_settings: + _stored_in_db = False - _response_obj = ConfigList( - field_name=field_name, - field_type=allowed_args[field_name]["type"], - field_description=field_info.description or "", - field_value=general_settings.get(field_name, None), - stored_in_db=_stored_in_db, - field_default_value=field_info.default, - nested_fields=nested_fields, - ) - return_val.append(_response_obj) + _response_obj = ConfigList( + field_name=field_name, + field_type=allowed_args[field_name]["type"], + field_description=field_info.description or "", + field_value=general_settings.get(field_name, None), + stored_in_db=_stored_in_db, + field_default_value=field_info.default, + nested_fields=nested_fields, + ) + return_val.append(_response_obj) return return_val