feat(pass_through_endpoints.py): initial working CRUD endpoints for /pass_through_endoints

This commit is contained in:
Krrish Dholakia 2024-08-15 21:23:26 -07:00
parent 28faafadb1
commit 589da45c24
3 changed files with 264 additions and 38 deletions

View file

@ -20,7 +20,14 @@ from fastapi.responses import StreamingResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
from litellm.proxy._types import (
ConfigFieldInfo,
ConfigFieldUpdate,
PassThroughEndpointResponse,
PassThroughGenericEndpoint,
ProxyException,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
@ -481,16 +488,64 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
@router.get(
"/config/pass_through_endpoint/{endpoint_id}",
"/config/pass_through_endpoint",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
response_model=PassThroughEndpointResponse,
)
async def get_pass_through_endpoints(request: Request, endpoint_id: str):
async def get_pass_through_endpoints(
endpoint_id: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
GET configured pass through endpoint.
If no endpoint_id given, return all configured endpoints.
"""
from litellm.proxy.proxy_server import get_config_general_settings
## Get existing pass-through endpoint field value
try:
response: ConfigFieldInfo = await get_config_general_settings(
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
)
except Exception:
return PassThroughEndpointResponse(endpoints=[])
pass_through_endpoint_data: Optional[List] = response.field_value
if pass_through_endpoint_data is None:
return PassThroughEndpointResponse(endpoints=[])
returned_endpoints = []
if endpoint_id is None:
for endpoint in pass_through_endpoint_data:
if isinstance(endpoint, dict):
returned_endpoints.append(PassThroughGenericEndpoint(**endpoint))
elif isinstance(endpoint, PassThroughGenericEndpoint):
returned_endpoints.append(endpoint)
elif endpoint_id is not None:
for endpoint in pass_through_endpoint_data:
_endpoint: Optional[PassThroughGenericEndpoint] = None
if isinstance(endpoint, dict):
_endpoint = PassThroughGenericEndpoint(**endpoint)
elif isinstance(endpoint, PassThroughGenericEndpoint):
_endpoint = endpoint
if _endpoint is not None and _endpoint.path == endpoint_id:
returned_endpoints.append(_endpoint)
return PassThroughEndpointResponse(endpoints=returned_endpoints)
@router.post(
"/config/pass_through_endpoint/{endpoint_id}",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_pass_through_endpoints(request: Request, endpoint_id: str):
"""
Update a pass-through endpoint
"""
pass
@ -499,20 +554,119 @@ async def get_pass_through_endpoints(request: Request, endpoint_id: str):
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
)
async def create_pass_through_endpoints(request: Request):
async def create_pass_through_endpoints(
data: PassThroughGenericEndpoint,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create new pass-through endpoint
"""
pass
from litellm.proxy.proxy_server import (
get_config_general_settings,
update_config_general_settings,
)
## Get existing pass-through endpoint field value
try:
response: ConfigFieldInfo = await get_config_general_settings(
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
)
except Exception:
response = ConfigFieldInfo(
field_name="pass_through_endpoints", field_value=None
)
## Update field with new endpoint
data_dict = data.model_dump()
if response.field_value is None:
response.field_value = [data_dict]
elif isinstance(response.field_value, List):
response.field_value.append(data_dict)
## Update db
updated_data = ConfigFieldUpdate(
field_name="pass_through_endpoints",
field_value=response.field_value,
config_type="general_settings",
)
await update_config_general_settings(
data=updated_data, user_api_key_dict=user_api_key_dict
)
@router.delete(
"/config/pass_through_endpoint",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
response_model=PassThroughEndpointResponse,
)
async def delete_pass_through_endpoints(request: Request):
async def delete_pass_through_endpoints(
endpoint_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create new pass-through endpoint
Delete a pass-through endpoint
Returns - the deleted endpoint
"""
pass
from litellm.proxy.proxy_server import (
get_config_general_settings,
update_config_general_settings,
)
## Get existing pass-through endpoint field value
try:
response: ConfigFieldInfo = await get_config_general_settings(
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
)
except Exception:
response = ConfigFieldInfo(
field_name="pass_through_endpoints", field_value=None
)
## Update field by removing endpoint
pass_through_endpoint_data: Optional[List] = response.field_value
response_obj: Optional[PassThroughGenericEndpoint] = None
if response.field_value is None or pass_through_endpoint_data is None:
raise HTTPException(
status_code=400,
detail={"error": "There are no pass-through endpoints setup."},
)
elif isinstance(response.field_value, List):
invalid_idx: Optional[int] = None
for idx, endpoint in enumerate(pass_through_endpoint_data):
_endpoint: Optional[PassThroughGenericEndpoint] = None
if isinstance(endpoint, dict):
_endpoint = PassThroughGenericEndpoint(**endpoint)
elif isinstance(endpoint, PassThroughGenericEndpoint):
_endpoint = endpoint
if _endpoint is not None and _endpoint.path == endpoint_id:
invalid_idx = idx
response_obj = _endpoint
if invalid_idx is not None:
pass_through_endpoint_data.pop(invalid_idx)
## Update db
updated_data = ConfigFieldUpdate(
field_name="pass_through_endpoints",
field_value=pass_through_endpoint_data,
config_type="general_settings",
)
await update_config_general_settings(
data=updated_data, user_api_key_dict=user_api_key_dict
)
if response_obj is None:
raise HTTPException(
status_code=400,
detail={
"error": "Endpoint={} was not found in pass-through endpoint list.".format(
endpoint_id
)
},
)
return PassThroughEndpointResponse(endpoints=[response_obj])