mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
use ProxyBaseLLMRequestProcessing
This commit is contained in:
parent
bfc928192d
commit
055a4fa2d5
5 changed files with 326 additions and 237 deletions
|
@ -239,6 +239,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
||||||
"image_generation",
|
"image_generation",
|
||||||
"moderation",
|
"moderation",
|
||||||
"audio_transcription",
|
"audio_transcription",
|
||||||
|
"responses",
|
||||||
],
|
],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
pass
|
pass
|
||||||
|
|
291
litellm/proxy/common_request_processing.py
Normal file
291
litellm/proxy/common_request_processing.py
Normal file
|
@ -0,0 +1,291 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import Response, StreamingResponse
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
get_logging_caching_headers,
|
||||||
|
get_remaining_tokens_and_requests_from_request_data,
|
||||||
|
)
|
||||||
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||||
|
from litellm.proxy.route_llm_request import route_request
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.router import Router
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
|
||||||
|
|
||||||
|
ProxyConfig = _ProxyConfig
|
||||||
|
else:
|
||||||
|
ProxyConfig = Any
|
||||||
|
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyBaseLLMRequestProcessing:
    """Request-processing helpers shared by the proxy's LLM endpoints.

    Groups the logic for building the ``x-litellm-*`` response headers and the
    common request pipeline (pre-call hooks, routing, post-call hooks) used by
    both the chat completions and the responses API routes.
    """

    @staticmethod
    def get_custom_headers(
        *,
        user_api_key_dict: UserAPIKeyAuth,
        call_id: Optional[str] = None,
        model_id: Optional[str] = None,
        cache_key: Optional[str] = None,
        api_base: Optional[str] = None,
        version: Optional[str] = None,
        model_region: Optional[str] = None,
        response_cost: Optional[Union[float, str]] = None,
        hidden_params: Optional[dict] = None,
        fastest_response_batch_completion: Optional[bool] = None,
        request_data: Optional[dict] = None,
        timeout: Optional[Union[float, int, httpx.Timeout]] = None,
        **kwargs,
    ) -> dict:
        """Build the custom ``x-litellm-*`` response headers.

        Args:
            user_api_key_dict: Key object whose ``tpm_limit``, ``rpm_limit``,
                ``max_budget`` and ``spend`` attributes are reported.
            call_id, model_id, cache_key, api_base, version, model_region,
            response_cost, timeout: Values copied into their matching headers.
            hidden_params: Optional dict read for ``_response_ms`` and
                ``litellm_overhead_time_ms``.
            fastest_response_batch_completion: Reported only when not ``None``.
            request_data: When truthy, the remaining-token/request headers and
                the logging/caching headers derived from it are merged in.
            **kwargs: Extra headers; each value is stringified.

        Returns:
            A ``str -> str`` dict without any entry whose value is ``""``,
            ``None`` or ``"None"``. Returns ``{}`` if filtering the headers
            raises (the error is logged).
        """
        exclude_values = {"", None, "None"}
        hidden_params = hidden_params or {}
        headers = {
            "x-litellm-call-id": call_id,
            "x-litellm-model-id": model_id,
            "x-litellm-cache-key": cache_key,
            "x-litellm-model-api-base": api_base,
            "x-litellm-version": version,
            "x-litellm-model-region": model_region,
            "x-litellm-response-cost": str(response_cost),
            "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
            "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
            "x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
            "x-litellm-key-spend": str(user_api_key_dict.spend),
            "x-litellm-response-duration-ms": str(
                hidden_params.get("_response_ms", None)
            ),
            "x-litellm-overhead-duration-ms": str(
                hidden_params.get("litellm_overhead_time_ms", None)
            ),
            "x-litellm-fastest_response_batch_completion": (
                str(fastest_response_batch_completion)
                if fastest_response_batch_completion is not None
                else None
            ),
            "x-litellm-timeout": str(timeout) if timeout is not None else None,
            **{k: str(v) for k, v in kwargs.items()},
        }
        if request_data:
            remaining_tokens_header = (
                get_remaining_tokens_and_requests_from_request_data(request_data)
            )
            headers.update(remaining_tokens_header)

            logging_caching_headers = get_logging_caching_headers(request_data)
            if logging_caching_headers:
                headers.update(logging_caching_headers)

        try:
            # Missing values were stringified to "None" above, so the exclusion
            # set has to contain both None and "None".
            return {
                key: str(value)
                for key, value in headers.items()
                if value not in exclude_values
            }
        except Exception as e:
            verbose_proxy_logger.error(f"Error setting custom headers: {e}")
            return {}

    @staticmethod
    async def base_process_llm_request(
        data: dict,
        request: Request,
        fastapi_response: Response,
        user_api_key_dict: UserAPIKeyAuth,
        route_type: Literal["acompletion", "aresponses"],
        proxy_logging_obj: ProxyLogging,
        general_settings: dict,
        proxy_config: ProxyConfig,
        select_data_generator: Callable,
        llm_router: Optional[Router] = None,
        model: Optional[str] = None,
        user_model: Optional[str] = None,
        user_temperature: Optional[float] = None,
        user_request_timeout: Optional[float] = None,
        user_max_tokens: Optional[int] = None,
        user_api_base: Optional[str] = None,
        version: Optional[str] = None,
    ) -> Any:
        """
        Common request processing logic for both chat completions and responses API endpoints

        Enriches ``data``, resolves the model, runs the pre-call hook, sets up
        the logging object, then awaits the during-call hook and the routed
        LLM call together.

        Returns:
            A ``StreamingResponse`` when ``data["stream"] is True``; otherwise
            the response returned by ``post_call_success_hook``, after the
            custom headers have been added to ``fastapi_response``.
        """
        verbose_proxy_logger.debug(
            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
        )

        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        data["model"] = (
            general_settings.get("completion_model", None)  # server default
            or user_model  # model name passed via cli args
            or model  # for azure deployments
            or data.get("model", None)  # default passed in http request
        )

        # override with user settings, these are params passed via cli
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        ### MODEL ALIAS MAPPING ###
        # check if model name in model alias map
        # get the actual model name
        if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

        ### CALL HOOKS ### - modify/reject incoming data before calling the model
        # NOTE(review): call_type is fixed to "completion" here even when
        # route_type is "aresponses", unlike during_call_hook below - confirm
        # this is intended.
        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
        )

        ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
        ## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
        data["litellm_call_id"] = request.headers.get(
            "x-litellm-call-id", str(uuid.uuid4())
        )
        logging_obj, data = litellm.utils.function_setup(
            original_function=route_type,
            rules_obj=litellm.utils.Rules(),
            start_time=datetime.now(),
            **data,
        )

        data["litellm_logging_obj"] = logging_obj

        ### ROUTE THE REQUEST ###
        # Do not change this - it should be a constant time fetch - ALWAYS
        llm_call = await route_request(
            data=data,
            route_type=route_type,
            llm_router=llm_router,
            user_model=user_model,
        )

        # Create the during-call hook coroutine only after routing succeeded:
        # if route_request raises, no un-awaited coroutine is left behind.
        # NOTE(review): assumes during_call_hook is a coroutine function, so
        # calling it here has no effect until it is awaited - confirm.
        during_call_hook = proxy_logging_obj.during_call_hook(
            data=data,
            user_api_key_dict=user_api_key_dict,
            call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
                route_type=route_type
            ),
        )
        # Order matters: the LLM response is read from index 1 below.
        tasks = [during_call_hook, llm_call]

        # wait for call to end
        llm_responses = asyncio.gather(
            *tasks
        )  # run the moderation check in parallel to the actual llm api call

        responses = await llm_responses

        response = responses[1]

        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        response_cost = hidden_params.get("response_cost", None) or ""
        fastest_response_batch_completion = hidden_params.get(
            "fastest_response_batch_completion", None
        )
        additional_headers: dict = hidden_params.get("additional_headers", {}) or {}

        # Post Call Processing
        if llm_router is not None:
            data["deployment"] = llm_router.get_deployment(model_id=model_id)
        # NOTE(review): the task is not referenced anywhere; the asyncio docs
        # recommend keeping a reference to fire-and-forget tasks - confirm
        # whether that matters here.
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )
        if (
            "stream" in data and data["stream"] is True
        ):  # use generate_responses to stream responses
            custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                call_id=logging_obj.litellm_call_id,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                response_cost=response_cost,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                fastest_response_batch_completion=fastest_response_batch_completion,
                request_data=data,
                hidden_params=hidden_params,
                **additional_headers,
            )
            selected_data_generator = select_data_generator(
                response=response,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
            )
            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
                headers=custom_headers,
            )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        hidden_params = (
            getattr(response, "_hidden_params", {}) or {}
        )  # get any updated response headers
        additional_headers = hidden_params.get("additional_headers", {}) or {}

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                call_id=logging_obj.litellm_call_id,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                response_cost=response_cost,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                fastest_response_batch_completion=fastest_response_batch_completion,
                request_data=data,
                hidden_params=hidden_params,
                **additional_headers,
            )
        )
        await check_response_size_is_safe(response=response)

        return response

    @staticmethod
    def _get_pre_call_type(
        route_type: Literal["acompletion", "aresponses"]
    ) -> Literal["completion", "responses"]:
        """Map a router ``route_type`` to the ``call_type`` passed to hooks.

        ``"acompletion"`` maps to ``"completion"`` and ``"aresponses"`` maps
        to ``"responses"``; no other value is handled.
        """
        if route_type == "acompletion":
            return "completion"
        elif route_type == "aresponses":
            return "responses"
|
|
@ -139,12 +139,9 @@ from litellm.proxy.batches_endpoints.endpoints import router as batches_router
|
||||||
|
|
||||||
## Import All Misc routes here ##
|
## Import All Misc routes here ##
|
||||||
from litellm.proxy.caching_routes import router as caching_router
|
from litellm.proxy.caching_routes import router as caching_router
|
||||||
|
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||||
from litellm.proxy.common_utils.admin_ui_utils import html_form
|
from litellm.proxy.common_utils.admin_ui_utils import html_form
|
||||||
from litellm.proxy.common_utils.callback_utils import (
|
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||||
get_logging_caching_headers,
|
|
||||||
get_remaining_tokens_and_requests_from_request_data,
|
|
||||||
initialize_callbacks_on_proxy,
|
|
||||||
)
|
|
||||||
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
|
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
|
||||||
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
|
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
|
||||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||||
|
@ -783,69 +780,6 @@ db_writer_client: Optional[AsyncHTTPHandler] = None
|
||||||
### logger ###
|
### logger ###
|
||||||
|
|
||||||
|
|
||||||
def get_custom_headers(
|
|
||||||
*,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
|
||||||
call_id: Optional[str] = None,
|
|
||||||
model_id: Optional[str] = None,
|
|
||||||
cache_key: Optional[str] = None,
|
|
||||||
api_base: Optional[str] = None,
|
|
||||||
version: Optional[str] = None,
|
|
||||||
model_region: Optional[str] = None,
|
|
||||||
response_cost: Optional[Union[float, str]] = None,
|
|
||||||
hidden_params: Optional[dict] = None,
|
|
||||||
fastest_response_batch_completion: Optional[bool] = None,
|
|
||||||
request_data: Optional[dict] = {},
|
|
||||||
timeout: Optional[Union[float, int, httpx.Timeout]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> dict:
|
|
||||||
exclude_values = {"", None, "None"}
|
|
||||||
hidden_params = hidden_params or {}
|
|
||||||
headers = {
|
|
||||||
"x-litellm-call-id": call_id,
|
|
||||||
"x-litellm-model-id": model_id,
|
|
||||||
"x-litellm-cache-key": cache_key,
|
|
||||||
"x-litellm-model-api-base": api_base,
|
|
||||||
"x-litellm-version": version,
|
|
||||||
"x-litellm-model-region": model_region,
|
|
||||||
"x-litellm-response-cost": str(response_cost),
|
|
||||||
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
|
|
||||||
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
|
|
||||||
"x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
|
|
||||||
"x-litellm-key-spend": str(user_api_key_dict.spend),
|
|
||||||
"x-litellm-response-duration-ms": str(hidden_params.get("_response_ms", None)),
|
|
||||||
"x-litellm-overhead-duration-ms": str(
|
|
||||||
hidden_params.get("litellm_overhead_time_ms", None)
|
|
||||||
),
|
|
||||||
"x-litellm-fastest_response_batch_completion": (
|
|
||||||
str(fastest_response_batch_completion)
|
|
||||||
if fastest_response_batch_completion is not None
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
"x-litellm-timeout": str(timeout) if timeout is not None else None,
|
|
||||||
**{k: str(v) for k, v in kwargs.items()},
|
|
||||||
}
|
|
||||||
if request_data:
|
|
||||||
remaining_tokens_header = get_remaining_tokens_and_requests_from_request_data(
|
|
||||||
request_data
|
|
||||||
)
|
|
||||||
headers.update(remaining_tokens_header)
|
|
||||||
|
|
||||||
logging_caching_headers = get_logging_caching_headers(request_data)
|
|
||||||
if logging_caching_headers:
|
|
||||||
headers.update(logging_caching_headers)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
key: str(value)
|
|
||||||
for key, value in headers.items()
|
|
||||||
if value not in exclude_values
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
async def check_request_disconnection(request: Request, llm_api_call_task):
|
async def check_request_disconnection(request: Request, llm_api_call_task):
|
||||||
"""
|
"""
|
||||||
Asynchronously checks if the request is disconnected at regular intervals.
|
Asynchronously checks if the request is disconnected at regular intervals.
|
||||||
|
@ -3518,169 +3452,29 @@ async def chat_completion( # noqa: PLR0915
|
||||||
|
|
||||||
"""
|
"""
|
||||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||||
|
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||||
data = {}
|
data = {}
|
||||||
try:
|
try:
|
||||||
data = await _read_request_body(request=request)
|
data = await _read_request_body(request=request)
|
||||||
verbose_proxy_logger.debug(
|
return await ProxyBaseLLMRequestProcessing.base_process_llm_request(
|
||||||
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
|
|
||||||
)
|
|
||||||
|
|
||||||
data = await add_litellm_data_to_request(
|
|
||||||
data=data,
|
data=data,
|
||||||
request=request,
|
request=request,
|
||||||
general_settings=general_settings,
|
fastapi_response=fastapi_response,
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
version=version,
|
|
||||||
proxy_config=proxy_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
data["model"] = (
|
|
||||||
general_settings.get("completion_model", None) # server default
|
|
||||||
or user_model # model name passed via cli args
|
|
||||||
or model # for azure deployments
|
|
||||||
or data.get("model", None) # default passed in http request
|
|
||||||
)
|
|
||||||
|
|
||||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
|
||||||
# override with user settings, these are params passed via cli
|
|
||||||
if user_temperature:
|
|
||||||
data["temperature"] = user_temperature
|
|
||||||
if user_request_timeout:
|
|
||||||
data["request_timeout"] = user_request_timeout
|
|
||||||
if user_max_tokens:
|
|
||||||
data["max_tokens"] = user_max_tokens
|
|
||||||
if user_api_base:
|
|
||||||
data["api_base"] = user_api_base
|
|
||||||
|
|
||||||
### MODEL ALIAS MAPPING ###
|
|
||||||
# check if model name in model alias map
|
|
||||||
# get the actual model name
|
|
||||||
if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
|
|
||||||
data["model"] = litellm.model_alias_map[data["model"]]
|
|
||||||
|
|
||||||
### CALL HOOKS ### - modify/reject incoming data before calling the model
|
|
||||||
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
|
||||||
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
|
|
||||||
)
|
|
||||||
|
|
||||||
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
|
|
||||||
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
|
|
||||||
data["litellm_call_id"] = request.headers.get(
|
|
||||||
"x-litellm-call-id", str(uuid.uuid4())
|
|
||||||
)
|
|
||||||
logging_obj, data = litellm.utils.function_setup(
|
|
||||||
original_function="acompletion",
|
|
||||||
rules_obj=litellm.utils.Rules(),
|
|
||||||
start_time=datetime.now(),
|
|
||||||
**data,
|
|
||||||
)
|
|
||||||
|
|
||||||
data["litellm_logging_obj"] = logging_obj
|
|
||||||
|
|
||||||
tasks = []
|
|
||||||
tasks.append(
|
|
||||||
proxy_logging_obj.during_call_hook(
|
|
||||||
data=data,
|
|
||||||
user_api_key_dict=user_api_key_dict,
|
|
||||||
call_type="completion",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
### ROUTE THE REQUEST ###
|
|
||||||
# Do not change this - it should be a constant time fetch - ALWAYS
|
|
||||||
llm_call = await route_request(
|
|
||||||
data=data,
|
|
||||||
route_type="acompletion",
|
route_type="acompletion",
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
llm_router=llm_router,
|
llm_router=llm_router,
|
||||||
|
general_settings=general_settings,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
select_data_generator=select_data_generator,
|
||||||
|
model=model,
|
||||||
user_model=user_model,
|
user_model=user_model,
|
||||||
|
user_temperature=user_temperature,
|
||||||
|
user_request_timeout=user_request_timeout,
|
||||||
|
user_max_tokens=user_max_tokens,
|
||||||
|
user_api_base=user_api_base,
|
||||||
|
version=version,
|
||||||
)
|
)
|
||||||
tasks.append(llm_call)
|
|
||||||
|
|
||||||
# wait for call to end
|
|
||||||
llm_responses = asyncio.gather(
|
|
||||||
*tasks
|
|
||||||
) # run the moderation check in parallel to the actual llm api call
|
|
||||||
|
|
||||||
responses = await llm_responses
|
|
||||||
|
|
||||||
response = responses[1]
|
|
||||||
|
|
||||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
||||||
model_id = hidden_params.get("model_id", None) or ""
|
|
||||||
cache_key = hidden_params.get("cache_key", None) or ""
|
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
|
||||||
response_cost = hidden_params.get("response_cost", None) or ""
|
|
||||||
fastest_response_batch_completion = hidden_params.get(
|
|
||||||
"fastest_response_batch_completion", None
|
|
||||||
)
|
|
||||||
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
|
||||||
|
|
||||||
# Post Call Processing
|
|
||||||
if llm_router is not None:
|
|
||||||
data["deployment"] = llm_router.get_deployment(model_id=model_id)
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.update_request_status(
|
|
||||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
"stream" in data and data["stream"] is True
|
|
||||||
): # use generate_responses to stream responses
|
|
||||||
custom_headers = get_custom_headers(
|
|
||||||
user_api_key_dict=user_api_key_dict,
|
|
||||||
call_id=logging_obj.litellm_call_id,
|
|
||||||
model_id=model_id,
|
|
||||||
cache_key=cache_key,
|
|
||||||
api_base=api_base,
|
|
||||||
version=version,
|
|
||||||
response_cost=response_cost,
|
|
||||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
||||||
fastest_response_batch_completion=fastest_response_batch_completion,
|
|
||||||
request_data=data,
|
|
||||||
hidden_params=hidden_params,
|
|
||||||
**additional_headers,
|
|
||||||
)
|
|
||||||
selected_data_generator = select_data_generator(
|
|
||||||
response=response,
|
|
||||||
user_api_key_dict=user_api_key_dict,
|
|
||||||
request_data=data,
|
|
||||||
)
|
|
||||||
return StreamingResponse(
|
|
||||||
selected_data_generator,
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers=custom_headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
### CALL HOOKS ### - modify outgoing data
|
|
||||||
response = await proxy_logging_obj.post_call_success_hook(
|
|
||||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_params = (
|
|
||||||
getattr(response, "_hidden_params", {}) or {}
|
|
||||||
) # get any updated response headers
|
|
||||||
additional_headers = hidden_params.get("additional_headers", {}) or {}
|
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
|
||||||
get_custom_headers(
|
|
||||||
user_api_key_dict=user_api_key_dict,
|
|
||||||
call_id=logging_obj.litellm_call_id,
|
|
||||||
model_id=model_id,
|
|
||||||
cache_key=cache_key,
|
|
||||||
api_base=api_base,
|
|
||||||
version=version,
|
|
||||||
response_cost=response_cost,
|
|
||||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
||||||
fastest_response_batch_completion=fastest_response_batch_completion,
|
|
||||||
request_data=data,
|
|
||||||
hidden_params=hidden_params,
|
|
||||||
**additional_headers,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await check_response_size_is_safe(response=response)
|
|
||||||
|
|
||||||
return response
|
|
||||||
except RejectedRequestError as e:
|
except RejectedRequestError as e:
|
||||||
_data = e.request_data
|
_data = e.request_data
|
||||||
await proxy_logging_obj.post_call_failure_hook(
|
await proxy_logging_obj.post_call_failure_hook(
|
||||||
|
@ -3734,7 +3528,7 @@ async def chat_completion( # noqa: PLR0915
|
||||||
_litellm_logging_obj: Optional[LiteLLMLoggingObj] = data.get(
|
_litellm_logging_obj: Optional[LiteLLMLoggingObj] = data.get(
|
||||||
"litellm_logging_obj", None
|
"litellm_logging_obj", None
|
||||||
)
|
)
|
||||||
custom_headers = get_custom_headers(
|
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
call_id=(
|
call_id=(
|
||||||
_litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None
|
_litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None
|
||||||
|
@ -3880,7 +3674,7 @@ async def completion( # noqa: PLR0915
|
||||||
if (
|
if (
|
||||||
"stream" in data and data["stream"] is True
|
"stream" in data and data["stream"] is True
|
||||||
): # use generate_responses to stream responses
|
): # use generate_responses to stream responses
|
||||||
custom_headers = get_custom_headers(
|
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
call_id=litellm_call_id,
|
call_id=litellm_call_id,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -3908,7 +3702,7 @@ async def completion( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
call_id=litellm_call_id,
|
call_id=litellm_call_id,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -4139,7 +3933,7 @@ async def embeddings( # noqa: PLR0915
|
||||||
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -4267,7 +4061,7 @@ async def image_generation(
|
||||||
litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
|
litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -4388,7 +4182,7 @@ async def audio_speech(
|
||||||
async for chunk in _generator:
|
async for chunk in _generator:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
custom_headers = get_custom_headers(
|
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -4529,7 +4323,7 @@ async def audio_transcriptions(
|
||||||
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -4681,7 +4475,7 @@ async def get_assistants(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -4780,7 +4574,7 @@ async def create_assistant(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -4877,7 +4671,7 @@ async def delete_assistant(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -4974,7 +4768,7 @@ async def create_threads(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -5070,7 +4864,7 @@ async def get_thread(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -5169,7 +4963,7 @@ async def add_messages(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -5264,7 +5058,7 @@ async def get_messages(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -5373,7 +5167,7 @@ async def run_thread(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
@ -5496,7 +5290,7 @@ async def moderations(
|
||||||
api_base = hidden_params.get("api_base", None) or ""
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
fastapi_response.headers.update(
|
fastapi_response.headers.update(
|
||||||
get_custom_headers(
|
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
cache_key=cache_key,
|
cache_key=cache_key,
|
||||||
|
|
|
@ -21,6 +21,7 @@ ROUTE_ENDPOINT_MAPPING = {
|
||||||
"atranscription": "/audio/transcriptions",
|
"atranscription": "/audio/transcriptions",
|
||||||
"amoderation": "/moderations",
|
"amoderation": "/moderations",
|
||||||
"arerank": "/rerank",
|
"arerank": "/rerank",
|
||||||
|
"aresponses": "/responses",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,6 +46,7 @@ async def route_request(
|
||||||
"atranscription",
|
"atranscription",
|
||||||
"amoderation",
|
"amoderation",
|
||||||
"arerank",
|
"arerank",
|
||||||
|
"aresponses",
|
||||||
"_arealtime", # private function for realtime API
|
"_arealtime", # private function for realtime API
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
|
|
|
@ -537,6 +537,7 @@ class ProxyLogging:
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
call_type: Literal[
|
call_type: Literal[
|
||||||
"completion",
|
"completion",
|
||||||
|
"responses",
|
||||||
"embeddings",
|
"embeddings",
|
||||||
"image_generation",
|
"image_generation",
|
||||||
"moderation",
|
"moderation",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue