import ast import asyncio import json import traceback from base64 import b64encode from typing import AsyncIterable, List, Optional import httpx from fastapi import ( APIRouter, Depends, FastAPI, HTTPException, Request, Response, status, ) from fastapi.responses import StreamingResponse import litellm from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import ( ConfigFieldInfo, ConfigFieldUpdate, PassThroughEndpointResponse, PassThroughGenericEndpoint, ProxyException, UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth router = APIRouter() async def set_env_variables_in_header(custom_headers: dict): """ checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc only runs for headers defined on config.yaml example header can be {"Authorization": "bearer os.environ/COHERE_API_KEY"} """ headers = {} for key, value in custom_headers.items(): # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys # we can then get the b64 encoded keys here if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY": # langfuse requires b64 encoded headers - we construct that here _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"] _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"] if isinstance( _langfuse_public_key, str ) and _langfuse_public_key.startswith("os.environ/"): _langfuse_public_key = litellm.get_secret(_langfuse_public_key) if isinstance( _langfuse_secret_key, str ) and _langfuse_secret_key.startswith("os.environ/"): _langfuse_secret_key = litellm.get_secret(_langfuse_secret_key) headers["Authorization"] = "Basic " + b64encode( f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") ).decode("ascii") else: # for all other headers headers[key] = value if isinstance(value, str) and "os.environ/" in value: verbose_proxy_logger.debug( "pass through endpoint - looking up 'os.environ/' variable" ) # get string section that is os.environ/ start_index = value.find("os.environ/") _variable_name = value[start_index:] verbose_proxy_logger.debug( "pass through endpoint - getting secret for variable name: %s", _variable_name, ) _secret_value = litellm.get_secret(_variable_name) new_value = value.replace(_variable_name, _secret_value) headers[key] = new_value return headers async def chat_completion_pass_through_endpoint( fastapi_response: Response, request: Request, adapter_id: str, user_api_key_dict: UserAPIKeyAuth, ): from litellm.proxy.proxy_server import ( add_litellm_data_to_request, general_settings, get_custom_headers, llm_router, proxy_config, proxy_logging_obj, user_api_base, user_max_tokens, user_model, user_request_timeout, user_temperature, version, ) data = {} try: body = await request.body() body_str = body.decode() try: data = ast.literal_eval(body_str) except Exception: data = json.loads(body_str) data["adapter_id"] = adapter_id verbose_proxy_logger.debug( "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), ) data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args or data["model"] # default passed in http request ) if user_model: data["model"] = user_model data = await add_litellm_data_to_request( data=data, # type: ignore request=request, general_settings=general_settings, user_api_key_dict=user_api_key_dict, version=version, proxy_config=proxy_config, ) # override with user settings, these are params passed via cli if user_temperature: data["temperature"] = user_temperature if user_request_timeout: data["request_timeout"] = user_request_timeout if user_max_tokens: data["max_tokens"] = user_max_tokens if user_api_base: data["api_base"] = user_api_base ### MODEL ALIAS MAPPING ### # check if model name in model alias map # get the actual model name if data["model"] in litellm.model_alias_map: data["model"] = litellm.model_alias_map[data["model"]] ### CALL HOOKS ### - modify incoming data before calling the model data = await proxy_logging_obj.pre_call_hook( # type: ignore user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" ) ### ROUTE THE REQUESTs ### router_model_names = llm_router.model_names if llm_router is not None else [] # skip router if user passed their key if "api_key" in data: llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) elif ( llm_router is not None and data["model"] in router_model_names ): # model in router model list llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif ( llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias ): # model set in model_group_alias llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif ( llm_router is not None and data["model"] in llm_router.deployment_names ): # model in router deployments, calling a specific deployment on the router llm_response = asyncio.create_task( llm_router.aadapter_completion(**data, specific_deployment=True) ) elif ( llm_router is not None and data["model"] in llm_router.get_model_ids() ): # model in router model list llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif ( llm_router is not None and data["model"] not in router_model_names and llm_router.default_deployment is not None ): # model in router deployments, calling a specific deployment on the router llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif user_model is not None: # `litellm --model ` llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={ "error": "completion: Invalid model name passed in model=" + data.get("model", "") }, ) # Await the llm_response task response = await llm_response hidden_params = getattr(response, "_hidden_params", {}) or {} model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" response_cost = hidden_params.get("response_cost", None) or "" ### ALERTING ### asyncio.create_task( proxy_logging_obj.update_request_status( litellm_call_id=data.get("litellm_call_id", ""), status="success" ) ) verbose_proxy_logger.debug("final response: %s", response) fastapi_response.headers.update( get_custom_headers( user_api_key_dict=user_api_key_dict, model_id=model_id, cache_key=cache_key, api_base=api_base, version=version, response_cost=response_cost, ) ) verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response)) return response except Exception as e: await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) verbose_proxy_logger.exception( "litellm.proxy.proxy_server.completion(): Exception occured - {}".format( str(e) ) ) error_msg = f"{str(e)}" raise ProxyException( message=getattr(e, "message", error_msg), type=getattr(e, "type", "None"), param=getattr(e, "param", "None"), code=getattr(e, "status_code", 500), ) def forward_headers_from_request( request: Request, headers: dict, forward_headers: Optional[bool] = False, ): """ Helper to forward headers from original request """ if forward_headers is True: request_headers = dict(request.headers) # Header We Should NOT forward request_headers.pop("content-length", None) request_headers.pop("host", None) # Combine request headers with custom headers headers = {**request_headers, **headers} return headers def get_response_headers(headers: httpx.Headers) -> dict: excluded_headers = {"transfer-encoding", "content-encoding"} return_headers = { key: value for key, value in headers.items() if key.lower() not in excluded_headers } return return_headers async def pass_through_request( request: Request, target: str, custom_headers: dict, user_api_key_dict: UserAPIKeyAuth, forward_headers: Optional[bool] = False, query_params: Optional[dict] = None, stream: Optional[bool] = None, ): try: import time import uuid from litellm.litellm_core_utils.litellm_logging import Logging from litellm.proxy.proxy_server import proxy_logging_obj url = httpx.URL(target) headers = custom_headers headers = forward_headers_from_request( request=request, headers=headers, forward_headers=forward_headers ) request_body = await request.body() body_str = request_body.decode() try: _parsed_body = ast.literal_eval(body_str) except Exception: _parsed_body = json.loads(body_str) verbose_proxy_logger.debug( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( url, headers, _parsed_body ) ) ### CALL HOOKS ### - modify incoming data / reject request before calling the model _parsed_body = await proxy_logging_obj.pre_call_hook( user_api_key_dict=user_api_key_dict, data=_parsed_body, call_type="pass_through_endpoint", ) async_client = httpx.AsyncClient() # create logging object start_time = time.time() logging_obj = Logging( model="unknown", messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}], stream=False, call_type="pass_through_endpoint", start_time=start_time, litellm_call_id=str(uuid.uuid4()), function_id="1245", ) # done for supporting 'parallel_request_limiter.py' with pass-through endpoints kwargs = { "litellm_params": { "metadata": { "user_api_key": user_api_key_dict.api_key, "user_api_key_user_id": user_api_key_dict.user_id, "user_api_key_team_id": user_api_key_dict.team_id, "user_api_key_end_user_id": user_api_key_dict.user_id, } }, "call_type": "pass_through_endpoint", } logging_obj.update_environment_variables( model="unknown", user="unknown", optional_params={}, litellm_params=kwargs["litellm_params"], call_type="pass_through_endpoint", ) # combine url with query params for logging requested_query_params = query_params or request.query_params.__dict__ requested_query_params_str = "&".join( f"{k}={v}" for k, v in requested_query_params.items() ) if "?" in str(url): logging_url = str(url) + "&" + requested_query_params_str else: logging_url = str(url) + "?" + requested_query_params_str logging_obj.pre_call( input=[{"role": "user", "content": "no-message-pass-through-endpoint"}], api_key="", additional_args={ "complete_input_dict": _parsed_body, "api_base": logging_url, "headers": headers, }, ) if stream: req = async_client.build_request( "POST", url, json=_parsed_body, params=requested_query_params, headers=headers, ) response = await async_client.send(req, stream=stream) try: response.raise_for_status() except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, detail=await e.response.aread() ) # Create an async generator to yield the response content async def stream_response() -> AsyncIterable[bytes]: async for chunk in response.aiter_bytes(): yield chunk return StreamingResponse( stream_response(), headers=get_response_headers(response.headers), status_code=response.status_code, ) response = await async_client.request( method=request.method, url=url, headers=headers, params=requested_query_params, json=_parsed_body, ) if ( response.headers.get("content-type") is not None and response.headers["content-type"] == "text/event-stream" ): try: response.raise_for_status() except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, detail=await e.response.aread() ) # streaming response # Create an async generator to yield the response content async def stream_response() -> AsyncIterable[bytes]: async for chunk in response.aiter_bytes(): yield chunk return StreamingResponse( stream_response(), headers=get_response_headers(response.headers), status_code=response.status_code, ) try: response.raise_for_status() except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, detail=e.response.text ) if response.status_code >= 300: raise HTTPException(status_code=response.status_code, detail=response.text) content = await response.aread() ## LOG SUCCESS end_time = time.time() await logging_obj.async_success_handler( result="", start_time=start_time, end_time=end_time, cache_hit=False, ) return Response( content=content, status_code=response.status_code, headers=get_response_headers(response.headers), ) except Exception as e: verbose_proxy_logger.exception( "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format( str(e) ) ) if isinstance(e, HTTPException): raise ProxyException( message=getattr(e, "message", str(e.detail)), type=getattr(e, "type", "None"), param=getattr(e, "param", "None"), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), ) else: error_msg = f"{str(e)}" raise ProxyException( message=getattr(e, "message", error_msg), type=getattr(e, "type", "None"), param=getattr(e, "param", "None"), code=getattr(e, "status_code", 500), ) def create_pass_through_route( endpoint, target: str, custom_headers: Optional[dict] = None, _forward_headers: Optional[bool] = False, dependencies: Optional[List] = None, ): # check if target is an adapter.py or a url import uuid from litellm.proxy.utils import get_instance_fn try: if isinstance(target, CustomLogger): adapter = target else: adapter = get_instance_fn(value=target) adapter_id = str(uuid.uuid4()) litellm.adapters = [{"id": adapter_id, "adapter": adapter}] async def endpoint_func( request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): return await chat_completion_pass_through_endpoint( fastapi_response=fastapi_response, request=request, adapter_id=adapter_id, user_api_key_dict=user_api_key_dict, ) except Exception: verbose_proxy_logger.debug("Defaulting to target being a url.") async def endpoint_func( # type: ignore request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), query_params: Optional[dict] = None, stream: Optional[ bool ] = None, # if pass-through endpoint is a streaming request ): return await pass_through_request( # type: ignore request=request, target=target, custom_headers=custom_headers or {}, user_api_key_dict=user_api_key_dict, forward_headers=_forward_headers, query_params=query_params, stream=stream, ) return endpoint_func async def initialize_pass_through_endpoints(pass_through_endpoints: list): verbose_proxy_logger.debug("initializing pass through endpoints") from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes from litellm.proxy.proxy_server import app, premium_user for endpoint in pass_through_endpoints: _target = endpoint.get("target", None) _path = endpoint.get("path", None) _custom_headers = endpoint.get("headers", None) _custom_headers = await set_env_variables_in_header( custom_headers=_custom_headers ) _forward_headers = endpoint.get("forward_headers", None) _auth = endpoint.get("auth", None) _dependencies = None if _auth is not None and str(_auth).lower() == "true": if premium_user is not True: raise ValueError( "Error Setting Authentication on Pass Through Endpoint: {}".format( CommonProxyErrors.not_premium_user.value ) ) _dependencies = [Depends(user_api_key_auth)] LiteLLMRoutes.openai_routes.value.append(_path) if _target is None: continue verbose_proxy_logger.debug( "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies ) app.add_api_route( path=_path, endpoint=create_pass_through_route( _path, _target, _custom_headers, _forward_headers, _dependencies ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=_dependencies, ) verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path) @router.get( "/config/pass_through_endpoint", tags=["Internal User management"], dependencies=[Depends(user_api_key_auth)], response_model=PassThroughEndpointResponse, ) async def get_pass_through_endpoints( endpoint_id: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ GET configured pass through endpoint. If no endpoint_id given, return all configured endpoints. """ from litellm.proxy.proxy_server import get_config_general_settings ## Get existing pass-through endpoint field value try: response: ConfigFieldInfo = await get_config_general_settings( field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict ) except Exception: return PassThroughEndpointResponse(endpoints=[]) pass_through_endpoint_data: Optional[List] = response.field_value if pass_through_endpoint_data is None: return PassThroughEndpointResponse(endpoints=[]) returned_endpoints = [] if endpoint_id is None: for endpoint in pass_through_endpoint_data: if isinstance(endpoint, dict): returned_endpoints.append(PassThroughGenericEndpoint(**endpoint)) elif isinstance(endpoint, PassThroughGenericEndpoint): returned_endpoints.append(endpoint) elif endpoint_id is not None: for endpoint in pass_through_endpoint_data: _endpoint: Optional[PassThroughGenericEndpoint] = None if isinstance(endpoint, dict): _endpoint = PassThroughGenericEndpoint(**endpoint) elif isinstance(endpoint, PassThroughGenericEndpoint): _endpoint = endpoint if _endpoint is not None and _endpoint.path == endpoint_id: returned_endpoints.append(_endpoint) return PassThroughEndpointResponse(endpoints=returned_endpoints) @router.post( "/config/pass_through_endpoint/{endpoint_id}", tags=["Internal User management"], dependencies=[Depends(user_api_key_auth)], ) async def update_pass_through_endpoints(request: Request, endpoint_id: str): """ Update a pass-through endpoint """ pass @router.post( "/config/pass_through_endpoint", tags=["Internal User management"], dependencies=[Depends(user_api_key_auth)], ) async def create_pass_through_endpoints( data: PassThroughGenericEndpoint, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ Create new pass-through endpoint """ from litellm.proxy.proxy_server import ( get_config_general_settings, update_config_general_settings, ) ## Get existing pass-through endpoint field value try: response: ConfigFieldInfo = await get_config_general_settings( field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict ) except Exception: response = ConfigFieldInfo( field_name="pass_through_endpoints", field_value=None ) ## Update field with new endpoint data_dict = data.model_dump() if response.field_value is None: response.field_value = [data_dict] elif isinstance(response.field_value, List): response.field_value.append(data_dict) ## Update db updated_data = ConfigFieldUpdate( field_name="pass_through_endpoints", field_value=response.field_value, config_type="general_settings", ) await update_config_general_settings( data=updated_data, user_api_key_dict=user_api_key_dict ) @router.delete( "/config/pass_through_endpoint", tags=["Internal User management"], dependencies=[Depends(user_api_key_auth)], response_model=PassThroughEndpointResponse, ) async def delete_pass_through_endpoints( endpoint_id: str, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ Delete a pass-through endpoint Returns - the deleted endpoint """ from litellm.proxy.proxy_server import ( get_config_general_settings, update_config_general_settings, ) ## Get existing pass-through endpoint field value try: response: ConfigFieldInfo = await get_config_general_settings( field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict ) except Exception: response = ConfigFieldInfo( field_name="pass_through_endpoints", field_value=None ) ## Update field by removing endpoint pass_through_endpoint_data: Optional[List] = response.field_value response_obj: Optional[PassThroughGenericEndpoint] = None if response.field_value is None or pass_through_endpoint_data is None: raise HTTPException( status_code=400, detail={"error": "There are no pass-through endpoints setup."}, ) elif isinstance(response.field_value, List): invalid_idx: Optional[int] = None for idx, endpoint in enumerate(pass_through_endpoint_data): _endpoint: Optional[PassThroughGenericEndpoint] = None if isinstance(endpoint, dict): _endpoint = PassThroughGenericEndpoint(**endpoint) elif isinstance(endpoint, PassThroughGenericEndpoint): _endpoint = endpoint if _endpoint is not None and _endpoint.path == endpoint_id: invalid_idx = idx response_obj = _endpoint if invalid_idx is not None: pass_through_endpoint_data.pop(invalid_idx) ## Update db updated_data = ConfigFieldUpdate( field_name="pass_through_endpoints", field_value=pass_through_endpoint_data, config_type="general_settings", ) await update_config_general_settings( data=updated_data, user_api_key_dict=user_api_key_dict ) if response_obj is None: raise HTTPException( status_code=400, detail={ "error": "Endpoint={} was not found in pass-through endpoint list.".format( endpoint_id ) }, ) return PassThroughEndpointResponse(endpoints=[response_obj])