mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(pass_through_endpoints.py): initial working CRUD endpoints for /pass_through_endoints
This commit is contained in:
parent
28faafadb1
commit
589da45c24
3 changed files with 264 additions and 38 deletions
|
@ -1082,10 +1082,18 @@ class DynamoDBArgs(LiteLLMBase):
|
||||||
assume_role_aws_session_name: Optional[str] = None
|
assume_role_aws_session_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class PassThroughEndpointTypedDict(TypedDict):
|
class PassThroughGenericEndpoint(LiteLLMBase):
|
||||||
path: str
|
path: str = Field(description="The route to be added to the LiteLLM Proxy Server.")
|
||||||
target: str
|
target: str = Field(
|
||||||
headers: dict
|
description="The URL to which requests for this path should be forwarded."
|
||||||
|
)
|
||||||
|
headers: dict = Field(
|
||||||
|
description="Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PassThroughEndpointResponse(LiteLLMBase):
|
||||||
|
endpoints: List[PassThroughGenericEndpoint]
|
||||||
|
|
||||||
|
|
||||||
class ConfigFieldUpdate(LiteLLMBase):
|
class ConfigFieldUpdate(LiteLLMBase):
|
||||||
|
@ -1104,6 +1112,7 @@ class FieldDetail(BaseModel):
|
||||||
field_type: str
|
field_type: str
|
||||||
field_description: str
|
field_description: str
|
||||||
field_default_value: Any = None
|
field_default_value: Any = None
|
||||||
|
stored_in_db: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
class ConfigList(LiteLLMBase):
|
class ConfigList(LiteLLMBase):
|
||||||
|
@ -1219,7 +1228,7 @@ class ConfigGeneralSettings(LiteLLMBase):
|
||||||
default=False,
|
default=False,
|
||||||
description="Public model hub for users to see what models they have access to, supported openai params, etc.",
|
description="Public model hub for users to see what models they have access to, supported openai params, etc.",
|
||||||
)
|
)
|
||||||
pass_through_endpoints: Optional[PassThroughEndpointTypedDict] = Field(
|
pass_through_endpoints: Optional[List[PassThroughGenericEndpoint]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Set-up pass-through endpoints for provider-specific endpoints. Docs - https://docs.litellm.ai/docs/proxy/pass_through",
|
description="Set-up pass-through endpoints for provider-specific endpoints. Docs - https://docs.litellm.ai/docs/proxy/pass_through",
|
||||||
)
|
)
|
||||||
|
@ -1781,3 +1790,9 @@ class VirtualKeyEvent(LiteLLMBase):
|
||||||
created_by_user_role: str
|
created_by_user_role: str
|
||||||
created_by_key_alias: Optional[str]
|
created_by_key_alias: Optional[str]
|
||||||
request_kwargs: dict
|
request_kwargs: dict
|
||||||
|
|
||||||
|
|
||||||
|
class CreatePassThroughEndpoint(LiteLLMBase):
|
||||||
|
path: str
|
||||||
|
target: str
|
||||||
|
headers: dict
|
||||||
|
|
|
@ -20,7 +20,14 @@ from fastapi.responses import StreamingResponse
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
|
from litellm.proxy._types import (
|
||||||
|
ConfigFieldInfo,
|
||||||
|
ConfigFieldUpdate,
|
||||||
|
PassThroughEndpointResponse,
|
||||||
|
PassThroughGenericEndpoint,
|
||||||
|
ProxyException,
|
||||||
|
UserAPIKeyAuth,
|
||||||
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -481,16 +488,64 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/config/pass_through_endpoint/{endpoint_id}",
|
"/config/pass_through_endpoint",
|
||||||
tags=["Internal User management"],
|
tags=["Internal User management"],
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_model=PassThroughEndpointResponse,
|
||||||
)
|
)
|
||||||
async def get_pass_through_endpoints(request: Request, endpoint_id: str):
|
async def get_pass_through_endpoints(
|
||||||
|
endpoint_id: Optional[str] = None,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
GET configured pass through endpoint.
|
GET configured pass through endpoint.
|
||||||
|
|
||||||
If no endpoint_id given, return all configured endpoints.
|
If no endpoint_id given, return all configured endpoints.
|
||||||
"""
|
"""
|
||||||
|
from litellm.proxy.proxy_server import get_config_general_settings
|
||||||
|
|
||||||
|
## Get existing pass-through endpoint field value
|
||||||
|
try:
|
||||||
|
response: ConfigFieldInfo = await get_config_general_settings(
|
||||||
|
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return PassThroughEndpointResponse(endpoints=[])
|
||||||
|
|
||||||
|
pass_through_endpoint_data: Optional[List] = response.field_value
|
||||||
|
if pass_through_endpoint_data is None:
|
||||||
|
return PassThroughEndpointResponse(endpoints=[])
|
||||||
|
|
||||||
|
returned_endpoints = []
|
||||||
|
if endpoint_id is None:
|
||||||
|
for endpoint in pass_through_endpoint_data:
|
||||||
|
if isinstance(endpoint, dict):
|
||||||
|
returned_endpoints.append(PassThroughGenericEndpoint(**endpoint))
|
||||||
|
elif isinstance(endpoint, PassThroughGenericEndpoint):
|
||||||
|
returned_endpoints.append(endpoint)
|
||||||
|
elif endpoint_id is not None:
|
||||||
|
for endpoint in pass_through_endpoint_data:
|
||||||
|
_endpoint: Optional[PassThroughGenericEndpoint] = None
|
||||||
|
if isinstance(endpoint, dict):
|
||||||
|
_endpoint = PassThroughGenericEndpoint(**endpoint)
|
||||||
|
elif isinstance(endpoint, PassThroughGenericEndpoint):
|
||||||
|
_endpoint = endpoint
|
||||||
|
|
||||||
|
if _endpoint is not None and _endpoint.path == endpoint_id:
|
||||||
|
returned_endpoints.append(_endpoint)
|
||||||
|
|
||||||
|
return PassThroughEndpointResponse(endpoints=returned_endpoints)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/config/pass_through_endpoint/{endpoint_id}",
|
||||||
|
tags=["Internal User management"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
)
|
||||||
|
async def update_pass_through_endpoints(request: Request, endpoint_id: str):
|
||||||
|
"""
|
||||||
|
Update a pass-through endpoint
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -499,20 +554,119 @@ async def get_pass_through_endpoints(request: Request, endpoint_id: str):
|
||||||
tags=["Internal User management"],
|
tags=["Internal User management"],
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
)
|
)
|
||||||
async def create_pass_through_endpoints(request: Request):
|
async def create_pass_through_endpoints(
|
||||||
|
data: PassThroughGenericEndpoint,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create new pass-through endpoint
|
Create new pass-through endpoint
|
||||||
"""
|
"""
|
||||||
pass
|
from litellm.proxy.proxy_server import (
|
||||||
|
get_config_general_settings,
|
||||||
|
update_config_general_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
## Get existing pass-through endpoint field value
|
||||||
|
|
||||||
|
try:
|
||||||
|
response: ConfigFieldInfo = await get_config_general_settings(
|
||||||
|
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
response = ConfigFieldInfo(
|
||||||
|
field_name="pass_through_endpoints", field_value=None
|
||||||
|
)
|
||||||
|
|
||||||
|
## Update field with new endpoint
|
||||||
|
data_dict = data.model_dump()
|
||||||
|
if response.field_value is None:
|
||||||
|
response.field_value = [data_dict]
|
||||||
|
elif isinstance(response.field_value, List):
|
||||||
|
response.field_value.append(data_dict)
|
||||||
|
|
||||||
|
## Update db
|
||||||
|
updated_data = ConfigFieldUpdate(
|
||||||
|
field_name="pass_through_endpoints",
|
||||||
|
field_value=response.field_value,
|
||||||
|
config_type="general_settings",
|
||||||
|
)
|
||||||
|
await update_config_general_settings(
|
||||||
|
data=updated_data, user_api_key_dict=user_api_key_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/config/pass_through_endpoint",
|
"/config/pass_through_endpoint",
|
||||||
tags=["Internal User management"],
|
tags=["Internal User management"],
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_model=PassThroughEndpointResponse,
|
||||||
)
|
)
|
||||||
async def delete_pass_through_endpoints(request: Request):
|
async def delete_pass_through_endpoints(
|
||||||
|
endpoint_id: str,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create new pass-through endpoint
|
Delete a pass-through endpoint
|
||||||
|
|
||||||
|
Returns - the deleted endpoint
|
||||||
"""
|
"""
|
||||||
pass
|
from litellm.proxy.proxy_server import (
|
||||||
|
get_config_general_settings,
|
||||||
|
update_config_general_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
## Get existing pass-through endpoint field value
|
||||||
|
|
||||||
|
try:
|
||||||
|
response: ConfigFieldInfo = await get_config_general_settings(
|
||||||
|
field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
response = ConfigFieldInfo(
|
||||||
|
field_name="pass_through_endpoints", field_value=None
|
||||||
|
)
|
||||||
|
|
||||||
|
## Update field by removing endpoint
|
||||||
|
pass_through_endpoint_data: Optional[List] = response.field_value
|
||||||
|
response_obj: Optional[PassThroughGenericEndpoint] = None
|
||||||
|
if response.field_value is None or pass_through_endpoint_data is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"error": "There are no pass-through endpoints setup."},
|
||||||
|
)
|
||||||
|
elif isinstance(response.field_value, List):
|
||||||
|
invalid_idx: Optional[int] = None
|
||||||
|
for idx, endpoint in enumerate(pass_through_endpoint_data):
|
||||||
|
_endpoint: Optional[PassThroughGenericEndpoint] = None
|
||||||
|
if isinstance(endpoint, dict):
|
||||||
|
_endpoint = PassThroughGenericEndpoint(**endpoint)
|
||||||
|
elif isinstance(endpoint, PassThroughGenericEndpoint):
|
||||||
|
_endpoint = endpoint
|
||||||
|
|
||||||
|
if _endpoint is not None and _endpoint.path == endpoint_id:
|
||||||
|
invalid_idx = idx
|
||||||
|
response_obj = _endpoint
|
||||||
|
|
||||||
|
if invalid_idx is not None:
|
||||||
|
pass_through_endpoint_data.pop(invalid_idx)
|
||||||
|
|
||||||
|
## Update db
|
||||||
|
updated_data = ConfigFieldUpdate(
|
||||||
|
field_name="pass_through_endpoints",
|
||||||
|
field_value=pass_through_endpoint_data,
|
||||||
|
config_type="general_settings",
|
||||||
|
)
|
||||||
|
await update_config_general_settings(
|
||||||
|
data=updated_data, user_api_key_dict=user_api_key_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if response_obj is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Endpoint={} was not found in pass-through endpoint list.".format(
|
||||||
|
endpoint_id
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return PassThroughEndpointResponse(endpoints=[response_obj])
|
||||||
|
|
|
@ -573,6 +573,23 @@ def _resolve_typed_dict_type(typ):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_pydantic_type(typ) -> List:
|
||||||
|
"""Resolve the actual TypedDict class from a potentially wrapped type."""
|
||||||
|
origin = get_origin(typ)
|
||||||
|
typs = []
|
||||||
|
if origin is Union: # Check if it's a Union (like Optional)
|
||||||
|
for arg in get_args(typ):
|
||||||
|
if (
|
||||||
|
arg is not None
|
||||||
|
and not isinstance(arg, type(None))
|
||||||
|
and "NoneType" not in str(arg)
|
||||||
|
):
|
||||||
|
typs.append(arg)
|
||||||
|
elif isinstance(typ, type) and isinstance(typ, BaseModel):
|
||||||
|
return [typ]
|
||||||
|
return typs
|
||||||
|
|
||||||
|
|
||||||
def prisma_setup(database_url: Optional[str]):
|
def prisma_setup(database_url: Optional[str]):
|
||||||
global prisma_client, proxy_logging_obj, user_api_key_cache
|
global prisma_client, proxy_logging_obj, user_api_key_cache
|
||||||
|
|
||||||
|
@ -2204,6 +2221,15 @@ class ProxyConfig:
|
||||||
alerting_args=general_settings["alerting_args"],
|
alerting_args=general_settings["alerting_args"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## PASS-THROUGH ENDPOINTS ##
|
||||||
|
if "pass_through_endpoints" in _general_settings:
|
||||||
|
general_settings["pass_through_endpoints"] = _general_settings[
|
||||||
|
"pass_through_endpoints"
|
||||||
|
]
|
||||||
|
await initialize_pass_through_endpoints(
|
||||||
|
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
||||||
|
)
|
||||||
|
|
||||||
async def add_deployment(
|
async def add_deployment(
|
||||||
self,
|
self,
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
|
@ -9434,7 +9460,7 @@ async def get_config_list(
|
||||||
"global_max_parallel_requests": {"type": "Integer"},
|
"global_max_parallel_requests": {"type": "Integer"},
|
||||||
"max_request_size_mb": {"type": "Integer"},
|
"max_request_size_mb": {"type": "Integer"},
|
||||||
"max_response_size_mb": {"type": "Integer"},
|
"max_response_size_mb": {"type": "Integer"},
|
||||||
"pass_through_endpoints": {"type": "TypedDictionary"},
|
"pass_through_endpoints": {"type": "PydanticModel"},
|
||||||
}
|
}
|
||||||
|
|
||||||
return_val = []
|
return_val = []
|
||||||
|
@ -9446,45 +9472,76 @@ async def get_config_list(
|
||||||
|
|
||||||
typed_dict_type = allowed_args[field_name]["type"]
|
typed_dict_type = allowed_args[field_name]["type"]
|
||||||
|
|
||||||
if typed_dict_type == "TypedDictionary":
|
if typed_dict_type == "PydanticModel":
|
||||||
typed_dict_class: Optional[Any] = _resolve_typed_dict_type(
|
pydantic_class_list: Optional[Any] = _resolve_pydantic_type(
|
||||||
field_info.annotation
|
field_info.annotation
|
||||||
)
|
)
|
||||||
|
if pydantic_class_list is None:
|
||||||
|
continue
|
||||||
|
|
||||||
if typed_dict_class is None:
|
for pydantic_class in pydantic_class_list:
|
||||||
nested_fields = None
|
|
||||||
else:
|
|
||||||
# Get type hints from the TypedDict to create FieldDetail objects
|
# Get type hints from the TypedDict to create FieldDetail objects
|
||||||
nested_fields = [
|
nested_fields = [
|
||||||
FieldDetail(
|
FieldDetail(
|
||||||
field_name=sub_field,
|
field_name=sub_field,
|
||||||
field_type=type_hint.__name__,
|
field_type=sub_field_type.__name__,
|
||||||
field_description="", # Add custom logic if descriptions are available
|
field_description="", # Add custom logic if descriptions are available
|
||||||
field_default_value=general_settings.get(sub_field, None),
|
field_default_value=general_settings.get(sub_field, None),
|
||||||
|
stored_in_db=None,
|
||||||
)
|
)
|
||||||
for sub_field, type_hint in get_type_hints(
|
for sub_field, sub_field_type in pydantic_class.__annotations__.items()
|
||||||
typed_dict_class
|
|
||||||
).items()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
for (
|
||||||
|
sub_field,
|
||||||
|
sub_field_info,
|
||||||
|
) in pydantic_class.model_fields.items():
|
||||||
|
if (
|
||||||
|
hasattr(sub_field_info, "description")
|
||||||
|
and sub_field_info.description is not None
|
||||||
|
):
|
||||||
|
nested_fields[idx].field_description = (
|
||||||
|
sub_field_info.description
|
||||||
|
)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
_stored_in_db = None
|
||||||
|
if field_name in db_general_settings_dict:
|
||||||
|
_stored_in_db = True
|
||||||
|
elif field_name in general_settings:
|
||||||
|
_stored_in_db = False
|
||||||
|
|
||||||
|
_response_obj = ConfigList(
|
||||||
|
field_name=field_name,
|
||||||
|
field_type=allowed_args[field_name]["type"],
|
||||||
|
field_description=field_info.description or "",
|
||||||
|
field_value=general_settings.get(field_name, None),
|
||||||
|
stored_in_db=_stored_in_db,
|
||||||
|
field_default_value=field_info.default,
|
||||||
|
nested_fields=nested_fields,
|
||||||
|
)
|
||||||
|
return_val.append(_response_obj)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
nested_fields = None
|
nested_fields = None
|
||||||
|
|
||||||
_stored_in_db = None
|
_stored_in_db = None
|
||||||
if field_name in db_general_settings_dict:
|
if field_name in db_general_settings_dict:
|
||||||
_stored_in_db = True
|
_stored_in_db = True
|
||||||
elif field_name in general_settings:
|
elif field_name in general_settings:
|
||||||
_stored_in_db = False
|
_stored_in_db = False
|
||||||
|
|
||||||
_response_obj = ConfigList(
|
_response_obj = ConfigList(
|
||||||
field_name=field_name,
|
field_name=field_name,
|
||||||
field_type=allowed_args[field_name]["type"],
|
field_type=allowed_args[field_name]["type"],
|
||||||
field_description=field_info.description or "",
|
field_description=field_info.description or "",
|
||||||
field_value=general_settings.get(field_name, None),
|
field_value=general_settings.get(field_name, None),
|
||||||
stored_in_db=_stored_in_db,
|
stored_in_db=_stored_in_db,
|
||||||
field_default_value=field_info.default,
|
field_default_value=field_info.default,
|
||||||
nested_fields=nested_fields,
|
nested_fields=nested_fields,
|
||||||
)
|
)
|
||||||
return_val.append(_response_obj)
|
return_val.append(_response_obj)
|
||||||
|
|
||||||
return return_val
|
return return_val
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue