mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(pass_through_endpoints.py): initial working CRUD endpoints for /pass_through_endoints
This commit is contained in:
parent
28faafadb1
commit
589da45c24
3 changed files with 264 additions and 38 deletions
|
@ -1082,10 +1082,18 @@ class DynamoDBArgs(LiteLLMBase):
|
|||
assume_role_aws_session_name: Optional[str] = None
|
||||
|
||||
|
||||
class PassThroughEndpointTypedDict(TypedDict):
|
||||
path: str
|
||||
target: str
|
||||
headers: dict
|
||||
class PassThroughGenericEndpoint(LiteLLMBase):
|
||||
path: str = Field(description="The route to be added to the LiteLLM Proxy Server.")
|
||||
target: str = Field(
|
||||
description="The URL to which requests for this path should be forwarded."
|
||||
)
|
||||
headers: dict = Field(
|
||||
description="Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint"
|
||||
)
|
||||
|
||||
|
||||
class PassThroughEndpointResponse(LiteLLMBase):
|
||||
endpoints: List[PassThroughGenericEndpoint]
|
||||
|
||||
|
||||
class ConfigFieldUpdate(LiteLLMBase):
|
||||
|
@ -1104,6 +1112,7 @@ class FieldDetail(BaseModel):
|
|||
field_type: str
|
||||
field_description: str
|
||||
field_default_value: Any = None
|
||||
stored_in_db: Optional[bool]
|
||||
|
||||
|
||||
class ConfigList(LiteLLMBase):
|
||||
|
@ -1219,7 +1228,7 @@ class ConfigGeneralSettings(LiteLLMBase):
|
|||
default=False,
|
||||
description="Public model hub for users to see what models they have access to, supported openai params, etc.",
|
||||
)
|
||||
pass_through_endpoints: Optional[PassThroughEndpointTypedDict] = Field(
|
||||
pass_through_endpoints: Optional[List[PassThroughGenericEndpoint]] = Field(
|
||||
default=None,
|
||||
description="Set-up pass-through endpoints for provider-specific endpoints. Docs - https://docs.litellm.ai/docs/proxy/pass_through",
|
||||
)
|
||||
|
@ -1781,3 +1790,9 @@ class VirtualKeyEvent(LiteLLMBase):
|
|||
created_by_user_role: str
|
||||
created_by_key_alias: Optional[str]
|
||||
request_kwargs: dict
|
||||
|
||||
|
||||
class CreatePassThroughEndpoint(LiteLLMBase):
|
||||
path: str
|
||||
target: str
|
||||
headers: dict
|
||||
|
|
|
@ -20,7 +20,14 @@ from fastapi.responses import StreamingResponse
|
|||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
|
||||
from litellm.proxy._types import (
|
||||
ConfigFieldInfo,
|
||||
ConfigFieldUpdate,
|
||||
PassThroughEndpointResponse,
|
||||
PassThroughGenericEndpoint,
|
||||
ProxyException,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -481,16 +488,64 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
|||
|
||||
|
||||
@router.get(
|
||||
"/config/pass_through_endpoint/{endpoint_id}",
|
||||
"/config/pass_through_endpoint",
|
||||
tags=["Internal User management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PassThroughEndpointResponse,
|
||||
)
|
||||
async def get_pass_through_endpoints(request: Request, endpoint_id: str):
|
||||
async def get_pass_through_endpoints(
|
||||
endpoint_id: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
GET configured pass through endpoint.
|
||||
|
||||
If no endpoint_id given, return all configured endpoints.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import get_config_general_settings
|
||||
|
||||
## Get existing pass-through endpoint field value
|
||||
try:
|
||||
response: ConfigFieldInfo = await get_config_general_settings(
|
||||
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
except Exception:
|
||||
return PassThroughEndpointResponse(endpoints=[])
|
||||
|
||||
pass_through_endpoint_data: Optional[List] = response.field_value
|
||||
if pass_through_endpoint_data is None:
|
||||
return PassThroughEndpointResponse(endpoints=[])
|
||||
|
||||
returned_endpoints = []
|
||||
if endpoint_id is None:
|
||||
for endpoint in pass_through_endpoint_data:
|
||||
if isinstance(endpoint, dict):
|
||||
returned_endpoints.append(PassThroughGenericEndpoint(**endpoint))
|
||||
elif isinstance(endpoint, PassThroughGenericEndpoint):
|
||||
returned_endpoints.append(endpoint)
|
||||
elif endpoint_id is not None:
|
||||
for endpoint in pass_through_endpoint_data:
|
||||
_endpoint: Optional[PassThroughGenericEndpoint] = None
|
||||
if isinstance(endpoint, dict):
|
||||
_endpoint = PassThroughGenericEndpoint(**endpoint)
|
||||
elif isinstance(endpoint, PassThroughGenericEndpoint):
|
||||
_endpoint = endpoint
|
||||
|
||||
if _endpoint is not None and _endpoint.path == endpoint_id:
|
||||
returned_endpoints.append(_endpoint)
|
||||
|
||||
return PassThroughEndpointResponse(endpoints=returned_endpoints)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/config/pass_through_endpoint/{endpoint_id}",
|
||||
tags=["Internal User management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def update_pass_through_endpoints(request: Request, endpoint_id: str):
|
||||
"""
|
||||
Update a pass-through endpoint
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
@ -499,20 +554,119 @@ async def get_pass_through_endpoints(request: Request, endpoint_id: str):
|
|||
tags=["Internal User management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def create_pass_through_endpoints(request: Request):
|
||||
async def create_pass_through_endpoints(
|
||||
data: PassThroughGenericEndpoint,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create new pass-through endpoint
|
||||
"""
|
||||
pass
|
||||
from litellm.proxy.proxy_server import (
|
||||
get_config_general_settings,
|
||||
update_config_general_settings,
|
||||
)
|
||||
|
||||
## Get existing pass-through endpoint field value
|
||||
|
||||
try:
|
||||
response: ConfigFieldInfo = await get_config_general_settings(
|
||||
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
except Exception:
|
||||
response = ConfigFieldInfo(
|
||||
field_name="pass_through_endpoints", field_value=None
|
||||
)
|
||||
|
||||
## Update field with new endpoint
|
||||
data_dict = data.model_dump()
|
||||
if response.field_value is None:
|
||||
response.field_value = [data_dict]
|
||||
elif isinstance(response.field_value, List):
|
||||
response.field_value.append(data_dict)
|
||||
|
||||
## Update db
|
||||
updated_data = ConfigFieldUpdate(
|
||||
field_name="pass_through_endpoints",
|
||||
field_value=response.field_value,
|
||||
config_type="general_settings",
|
||||
)
|
||||
await update_config_general_settings(
|
||||
data=updated_data, user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/config/pass_through_endpoint",
|
||||
tags=["Internal User management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PassThroughEndpointResponse,
|
||||
)
|
||||
async def delete_pass_through_endpoints(request: Request):
|
||||
async def delete_pass_through_endpoints(
|
||||
endpoint_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create new pass-through endpoint
|
||||
Delete a pass-through endpoint
|
||||
|
||||
Returns - the deleted endpoint
|
||||
"""
|
||||
pass
|
||||
from litellm.proxy.proxy_server import (
|
||||
get_config_general_settings,
|
||||
update_config_general_settings,
|
||||
)
|
||||
|
||||
## Get existing pass-through endpoint field value
|
||||
|
||||
try:
|
||||
response: ConfigFieldInfo = await get_config_general_settings(
|
||||
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
except Exception:
|
||||
response = ConfigFieldInfo(
|
||||
field_name="pass_through_endpoints", field_value=None
|
||||
)
|
||||
|
||||
## Update field by removing endpoint
|
||||
pass_through_endpoint_data: Optional[List] = response.field_value
|
||||
response_obj: Optional[PassThroughGenericEndpoint] = None
|
||||
if response.field_value is None or pass_through_endpoint_data is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "There are no pass-through endpoints setup."},
|
||||
)
|
||||
elif isinstance(response.field_value, List):
|
||||
invalid_idx: Optional[int] = None
|
||||
for idx, endpoint in enumerate(pass_through_endpoint_data):
|
||||
_endpoint: Optional[PassThroughGenericEndpoint] = None
|
||||
if isinstance(endpoint, dict):
|
||||
_endpoint = PassThroughGenericEndpoint(**endpoint)
|
||||
elif isinstance(endpoint, PassThroughGenericEndpoint):
|
||||
_endpoint = endpoint
|
||||
|
||||
if _endpoint is not None and _endpoint.path == endpoint_id:
|
||||
invalid_idx = idx
|
||||
response_obj = _endpoint
|
||||
|
||||
if invalid_idx is not None:
|
||||
pass_through_endpoint_data.pop(invalid_idx)
|
||||
|
||||
## Update db
|
||||
updated_data = ConfigFieldUpdate(
|
||||
field_name="pass_through_endpoints",
|
||||
field_value=pass_through_endpoint_data,
|
||||
config_type="general_settings",
|
||||
)
|
||||
await update_config_general_settings(
|
||||
data=updated_data, user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
|
||||
if response_obj is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Endpoint={} was not found in pass-through endpoint list.".format(
|
||||
endpoint_id
|
||||
)
|
||||
},
|
||||
)
|
||||
return PassThroughEndpointResponse(endpoints=[response_obj])
|
||||
|
|
|
@ -573,6 +573,23 @@ def _resolve_typed_dict_type(typ):
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_pydantic_type(typ) -> List:
|
||||
"""Resolve the actual TypedDict class from a potentially wrapped type."""
|
||||
origin = get_origin(typ)
|
||||
typs = []
|
||||
if origin is Union: # Check if it's a Union (like Optional)
|
||||
for arg in get_args(typ):
|
||||
if (
|
||||
arg is not None
|
||||
and not isinstance(arg, type(None))
|
||||
and "NoneType" not in str(arg)
|
||||
):
|
||||
typs.append(arg)
|
||||
elif isinstance(typ, type) and isinstance(typ, BaseModel):
|
||||
return [typ]
|
||||
return typs
|
||||
|
||||
|
||||
def prisma_setup(database_url: Optional[str]):
|
||||
global prisma_client, proxy_logging_obj, user_api_key_cache
|
||||
|
||||
|
@ -2204,6 +2221,15 @@ class ProxyConfig:
|
|||
alerting_args=general_settings["alerting_args"],
|
||||
)
|
||||
|
||||
## PASS-THROUGH ENDPOINTS ##
|
||||
if "pass_through_endpoints" in _general_settings:
|
||||
general_settings["pass_through_endpoints"] = _general_settings[
|
||||
"pass_through_endpoints"
|
||||
]
|
||||
await initialize_pass_through_endpoints(
|
||||
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
||||
)
|
||||
|
||||
async def add_deployment(
|
||||
self,
|
||||
prisma_client: PrismaClient,
|
||||
|
@ -9434,7 +9460,7 @@ async def get_config_list(
|
|||
"global_max_parallel_requests": {"type": "Integer"},
|
||||
"max_request_size_mb": {"type": "Integer"},
|
||||
"max_response_size_mb": {"type": "Integer"},
|
||||
"pass_through_endpoints": {"type": "TypedDictionary"},
|
||||
"pass_through_endpoints": {"type": "PydanticModel"},
|
||||
}
|
||||
|
||||
return_val = []
|
||||
|
@ -9446,45 +9472,76 @@ async def get_config_list(
|
|||
|
||||
typed_dict_type = allowed_args[field_name]["type"]
|
||||
|
||||
if typed_dict_type == "TypedDictionary":
|
||||
typed_dict_class: Optional[Any] = _resolve_typed_dict_type(
|
||||
if typed_dict_type == "PydanticModel":
|
||||
pydantic_class_list: Optional[Any] = _resolve_pydantic_type(
|
||||
field_info.annotation
|
||||
)
|
||||
if pydantic_class_list is None:
|
||||
continue
|
||||
|
||||
if typed_dict_class is None:
|
||||
nested_fields = None
|
||||
else:
|
||||
for pydantic_class in pydantic_class_list:
|
||||
# Get type hints from the TypedDict to create FieldDetail objects
|
||||
nested_fields = [
|
||||
FieldDetail(
|
||||
field_name=sub_field,
|
||||
field_type=type_hint.__name__,
|
||||
field_type=sub_field_type.__name__,
|
||||
field_description="", # Add custom logic if descriptions are available
|
||||
field_default_value=general_settings.get(sub_field, None),
|
||||
stored_in_db=None,
|
||||
)
|
||||
for sub_field, type_hint in get_type_hints(
|
||||
typed_dict_class
|
||||
).items()
|
||||
for sub_field, sub_field_type in pydantic_class.__annotations__.items()
|
||||
]
|
||||
|
||||
idx = 0
|
||||
for (
|
||||
sub_field,
|
||||
sub_field_info,
|
||||
) in pydantic_class.model_fields.items():
|
||||
if (
|
||||
hasattr(sub_field_info, "description")
|
||||
and sub_field_info.description is not None
|
||||
):
|
||||
nested_fields[idx].field_description = (
|
||||
sub_field_info.description
|
||||
)
|
||||
idx += 1
|
||||
|
||||
_stored_in_db = None
|
||||
if field_name in db_general_settings_dict:
|
||||
_stored_in_db = True
|
||||
elif field_name in general_settings:
|
||||
_stored_in_db = False
|
||||
|
||||
_response_obj = ConfigList(
|
||||
field_name=field_name,
|
||||
field_type=allowed_args[field_name]["type"],
|
||||
field_description=field_info.description or "",
|
||||
field_value=general_settings.get(field_name, None),
|
||||
stored_in_db=_stored_in_db,
|
||||
field_default_value=field_info.default,
|
||||
nested_fields=nested_fields,
|
||||
)
|
||||
return_val.append(_response_obj)
|
||||
|
||||
else:
|
||||
nested_fields = None
|
||||
|
||||
_stored_in_db = None
|
||||
if field_name in db_general_settings_dict:
|
||||
_stored_in_db = True
|
||||
elif field_name in general_settings:
|
||||
_stored_in_db = False
|
||||
_stored_in_db = None
|
||||
if field_name in db_general_settings_dict:
|
||||
_stored_in_db = True
|
||||
elif field_name in general_settings:
|
||||
_stored_in_db = False
|
||||
|
||||
_response_obj = ConfigList(
|
||||
field_name=field_name,
|
||||
field_type=allowed_args[field_name]["type"],
|
||||
field_description=field_info.description or "",
|
||||
field_value=general_settings.get(field_name, None),
|
||||
stored_in_db=_stored_in_db,
|
||||
field_default_value=field_info.default,
|
||||
nested_fields=nested_fields,
|
||||
)
|
||||
return_val.append(_response_obj)
|
||||
_response_obj = ConfigList(
|
||||
field_name=field_name,
|
||||
field_type=allowed_args[field_name]["type"],
|
||||
field_description=field_info.description or "",
|
||||
field_value=general_settings.get(field_name, None),
|
||||
stored_in_db=_stored_in_db,
|
||||
field_default_value=field_info.default,
|
||||
nested_fields=nested_fields,
|
||||
)
|
||||
return_val.append(_response_obj)
|
||||
|
||||
return return_val
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue