feat(pass_through_endpoints.py): initial working CRUD endpoints for /pass_through_endoints

This commit is contained in:
Krrish Dholakia 2024-08-15 21:23:26 -07:00
parent 28faafadb1
commit 589da45c24
3 changed files with 264 additions and 38 deletions

View file

@ -1082,10 +1082,18 @@ class DynamoDBArgs(LiteLLMBase):
assume_role_aws_session_name: Optional[str] = None assume_role_aws_session_name: Optional[str] = None
class PassThroughEndpointTypedDict(TypedDict): class PassThroughGenericEndpoint(LiteLLMBase):
path: str path: str = Field(description="The route to be added to the LiteLLM Proxy Server.")
target: str target: str = Field(
headers: dict description="The URL to which requests for this path should be forwarded."
)
headers: dict = Field(
description="Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint"
)
class PassThroughEndpointResponse(LiteLLMBase):
endpoints: List[PassThroughGenericEndpoint]
class ConfigFieldUpdate(LiteLLMBase): class ConfigFieldUpdate(LiteLLMBase):
@ -1104,6 +1112,7 @@ class FieldDetail(BaseModel):
field_type: str field_type: str
field_description: str field_description: str
field_default_value: Any = None field_default_value: Any = None
stored_in_db: Optional[bool]
class ConfigList(LiteLLMBase): class ConfigList(LiteLLMBase):
@ -1219,7 +1228,7 @@ class ConfigGeneralSettings(LiteLLMBase):
default=False, default=False,
description="Public model hub for users to see what models they have access to, supported openai params, etc.", description="Public model hub for users to see what models they have access to, supported openai params, etc.",
) )
pass_through_endpoints: Optional[PassThroughEndpointTypedDict] = Field( pass_through_endpoints: Optional[List[PassThroughGenericEndpoint]] = Field(
default=None, default=None,
description="Set-up pass-through endpoints for provider-specific endpoints. Docs - https://docs.litellm.ai/docs/proxy/pass_through", description="Set-up pass-through endpoints for provider-specific endpoints. Docs - https://docs.litellm.ai/docs/proxy/pass_through",
) )
@ -1781,3 +1790,9 @@ class VirtualKeyEvent(LiteLLMBase):
created_by_user_role: str created_by_user_role: str
created_by_key_alias: Optional[str] created_by_key_alias: Optional[str]
request_kwargs: dict request_kwargs: dict
class CreatePassThroughEndpoint(LiteLLMBase):
path: str
target: str
headers: dict

View file

@ -20,7 +20,14 @@ from fastapi.responses import StreamingResponse
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import ProxyException, UserAPIKeyAuth from litellm.proxy._types import (
ConfigFieldInfo,
ConfigFieldUpdate,
PassThroughEndpointResponse,
PassThroughGenericEndpoint,
ProxyException,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter() router = APIRouter()
@ -481,16 +488,64 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
@router.get( @router.get(
"/config/pass_through_endpoint/{endpoint_id}", "/config/pass_through_endpoint",
tags=["Internal User management"], tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
response_model=PassThroughEndpointResponse,
) )
async def get_pass_through_endpoints(request: Request, endpoint_id: str): async def get_pass_through_endpoints(
endpoint_id: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
""" """
GET configured pass through endpoint. GET configured pass through endpoint.
If no endpoint_id given, return all configured endpoints. If no endpoint_id given, return all configured endpoints.
""" """
from litellm.proxy.proxy_server import get_config_general_settings
## Get existing pass-through endpoint field value
try:
response: ConfigFieldInfo = await get_config_general_settings(
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
)
except Exception:
return PassThroughEndpointResponse(endpoints=[])
pass_through_endpoint_data: Optional[List] = response.field_value
if pass_through_endpoint_data is None:
return PassThroughEndpointResponse(endpoints=[])
returned_endpoints = []
if endpoint_id is None:
for endpoint in pass_through_endpoint_data:
if isinstance(endpoint, dict):
returned_endpoints.append(PassThroughGenericEndpoint(**endpoint))
elif isinstance(endpoint, PassThroughGenericEndpoint):
returned_endpoints.append(endpoint)
elif endpoint_id is not None:
for endpoint in pass_through_endpoint_data:
_endpoint: Optional[PassThroughGenericEndpoint] = None
if isinstance(endpoint, dict):
_endpoint = PassThroughGenericEndpoint(**endpoint)
elif isinstance(endpoint, PassThroughGenericEndpoint):
_endpoint = endpoint
if _endpoint is not None and _endpoint.path == endpoint_id:
returned_endpoints.append(_endpoint)
return PassThroughEndpointResponse(endpoints=returned_endpoints)
@router.post(
"/config/pass_through_endpoint/{endpoint_id}",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_pass_through_endpoints(request: Request, endpoint_id: str):
"""
Update a pass-through endpoint
"""
pass pass
@ -499,20 +554,119 @@ async def get_pass_through_endpoints(request: Request, endpoint_id: str):
tags=["Internal User management"], tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def create_pass_through_endpoints(request: Request): async def create_pass_through_endpoints(
data: PassThroughGenericEndpoint,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
""" """
Create new pass-through endpoint Create new pass-through endpoint
""" """
pass from litellm.proxy.proxy_server import (
get_config_general_settings,
update_config_general_settings,
)
## Get existing pass-through endpoint field value
try:
response: ConfigFieldInfo = await get_config_general_settings(
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
)
except Exception:
response = ConfigFieldInfo(
field_name="pass_through_endpoints", field_value=None
)
## Update field with new endpoint
data_dict = data.model_dump()
if response.field_value is None:
response.field_value = [data_dict]
elif isinstance(response.field_value, List):
response.field_value.append(data_dict)
## Update db
updated_data = ConfigFieldUpdate(
field_name="pass_through_endpoints",
field_value=response.field_value,
config_type="general_settings",
)
await update_config_general_settings(
data=updated_data, user_api_key_dict=user_api_key_dict
)
@router.delete( @router.delete(
"/config/pass_through_endpoint", "/config/pass_through_endpoint",
tags=["Internal User management"], tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
response_model=PassThroughEndpointResponse,
) )
async def delete_pass_through_endpoints(request: Request): async def delete_pass_through_endpoints(
endpoint_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
""" """
Create new pass-through endpoint Delete a pass-through endpoint
Returns - the deleted endpoint
""" """
pass from litellm.proxy.proxy_server import (
get_config_general_settings,
update_config_general_settings,
)
## Get existing pass-through endpoint field value
try:
response: ConfigFieldInfo = await get_config_general_settings(
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
)
except Exception:
response = ConfigFieldInfo(
field_name="pass_through_endpoints", field_value=None
)
## Update field by removing endpoint
pass_through_endpoint_data: Optional[List] = response.field_value
response_obj: Optional[PassThroughGenericEndpoint] = None
if response.field_value is None or pass_through_endpoint_data is None:
raise HTTPException(
status_code=400,
detail={"error": "There are no pass-through endpoints setup."},
)
elif isinstance(response.field_value, List):
invalid_idx: Optional[int] = None
for idx, endpoint in enumerate(pass_through_endpoint_data):
_endpoint: Optional[PassThroughGenericEndpoint] = None
if isinstance(endpoint, dict):
_endpoint = PassThroughGenericEndpoint(**endpoint)
elif isinstance(endpoint, PassThroughGenericEndpoint):
_endpoint = endpoint
if _endpoint is not None and _endpoint.path == endpoint_id:
invalid_idx = idx
response_obj = _endpoint
if invalid_idx is not None:
pass_through_endpoint_data.pop(invalid_idx)
## Update db
updated_data = ConfigFieldUpdate(
field_name="pass_through_endpoints",
field_value=pass_through_endpoint_data,
config_type="general_settings",
)
await update_config_general_settings(
data=updated_data, user_api_key_dict=user_api_key_dict
)
if response_obj is None:
raise HTTPException(
status_code=400,
detail={
"error": "Endpoint={} was not found in pass-through endpoint list.".format(
endpoint_id
)
},
)
return PassThroughEndpointResponse(endpoints=[response_obj])

View file

@ -573,6 +573,23 @@ def _resolve_typed_dict_type(typ):
return None return None
def _resolve_pydantic_type(typ) -> List:
"""Resolve the actual TypedDict class from a potentially wrapped type."""
origin = get_origin(typ)
typs = []
if origin is Union: # Check if it's a Union (like Optional)
for arg in get_args(typ):
if (
arg is not None
and not isinstance(arg, type(None))
and "NoneType" not in str(arg)
):
typs.append(arg)
elif isinstance(typ, type) and isinstance(typ, BaseModel):
return [typ]
return typs
def prisma_setup(database_url: Optional[str]): def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj, user_api_key_cache global prisma_client, proxy_logging_obj, user_api_key_cache
@ -2204,6 +2221,15 @@ class ProxyConfig:
alerting_args=general_settings["alerting_args"], alerting_args=general_settings["alerting_args"],
) )
## PASS-THROUGH ENDPOINTS ##
if "pass_through_endpoints" in _general_settings:
general_settings["pass_through_endpoints"] = _general_settings[
"pass_through_endpoints"
]
await initialize_pass_through_endpoints(
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
async def add_deployment( async def add_deployment(
self, self,
prisma_client: PrismaClient, prisma_client: PrismaClient,
@ -9434,7 +9460,7 @@ async def get_config_list(
"global_max_parallel_requests": {"type": "Integer"}, "global_max_parallel_requests": {"type": "Integer"},
"max_request_size_mb": {"type": "Integer"}, "max_request_size_mb": {"type": "Integer"},
"max_response_size_mb": {"type": "Integer"}, "max_response_size_mb": {"type": "Integer"},
"pass_through_endpoints": {"type": "TypedDictionary"}, "pass_through_endpoints": {"type": "PydanticModel"},
} }
return_val = [] return_val = []
@ -9446,45 +9472,76 @@ async def get_config_list(
typed_dict_type = allowed_args[field_name]["type"] typed_dict_type = allowed_args[field_name]["type"]
if typed_dict_type == "TypedDictionary": if typed_dict_type == "PydanticModel":
typed_dict_class: Optional[Any] = _resolve_typed_dict_type( pydantic_class_list: Optional[Any] = _resolve_pydantic_type(
field_info.annotation field_info.annotation
) )
if pydantic_class_list is None:
continue
if typed_dict_class is None: for pydantic_class in pydantic_class_list:
nested_fields = None
else:
# Get type hints from the TypedDict to create FieldDetail objects # Get type hints from the TypedDict to create FieldDetail objects
nested_fields = [ nested_fields = [
FieldDetail( FieldDetail(
field_name=sub_field, field_name=sub_field,
field_type=type_hint.__name__, field_type=sub_field_type.__name__,
field_description="", # Add custom logic if descriptions are available field_description="", # Add custom logic if descriptions are available
field_default_value=general_settings.get(sub_field, None), field_default_value=general_settings.get(sub_field, None),
stored_in_db=None,
) )
for sub_field, type_hint in get_type_hints( for sub_field, sub_field_type in pydantic_class.__annotations__.items()
typed_dict_class
).items()
] ]
idx = 0
for (
sub_field,
sub_field_info,
) in pydantic_class.model_fields.items():
if (
hasattr(sub_field_info, "description")
and sub_field_info.description is not None
):
nested_fields[idx].field_description = (
sub_field_info.description
)
idx += 1
_stored_in_db = None
if field_name in db_general_settings_dict:
_stored_in_db = True
elif field_name in general_settings:
_stored_in_db = False
_response_obj = ConfigList(
field_name=field_name,
field_type=allowed_args[field_name]["type"],
field_description=field_info.description or "",
field_value=general_settings.get(field_name, None),
stored_in_db=_stored_in_db,
field_default_value=field_info.default,
nested_fields=nested_fields,
)
return_val.append(_response_obj)
else: else:
nested_fields = None nested_fields = None
_stored_in_db = None _stored_in_db = None
if field_name in db_general_settings_dict: if field_name in db_general_settings_dict:
_stored_in_db = True _stored_in_db = True
elif field_name in general_settings: elif field_name in general_settings:
_stored_in_db = False _stored_in_db = False
_response_obj = ConfigList( _response_obj = ConfigList(
field_name=field_name, field_name=field_name,
field_type=allowed_args[field_name]["type"], field_type=allowed_args[field_name]["type"],
field_description=field_info.description or "", field_description=field_info.description or "",
field_value=general_settings.get(field_name, None), field_value=general_settings.get(field_name, None),
stored_in_db=_stored_in_db, stored_in_db=_stored_in_db,
field_default_value=field_info.default, field_default_value=field_info.default,
nested_fields=nested_fields, nested_fields=nested_fields,
) )
return_val.append(_response_obj) return_val.append(_response_obj)
return return_val return return_val