diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py
index 9ffc2666fc..7f131efb04 100644
--- a/litellm/proxy/common_request_processing.py
+++ b/litellm/proxy/common_request_processing.py
@@ -31,6 +31,8 @@ from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
 
 
 class ProxyBaseLLMRequestProcessing:
+    def __init__(self, data: dict):
+        self.data = data
 
     @staticmethod
     def get_custom_headers(
@@ -97,9 +99,8 @@ class ProxyBaseLLMRequestProcessing:
             verbose_proxy_logger.error(f"Error setting custom headers: {e}")
             return {}
 
-    @staticmethod
     async def base_process_llm_request(
-        data: dict,
+        self,
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
@@ -121,11 +122,11 @@ class ProxyBaseLLMRequestProcessing:
         Common request processing logic for both chat completions and responses API endpoints
         """
         verbose_proxy_logger.debug(
-            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
+            "Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
         )
 
-        data = await add_litellm_data_to_request(
-            data=data,
+        self.data = await add_litellm_data_to_request(
+            data=self.data,
             request=request,
             general_settings=general_settings,
             user_api_key_dict=user_api_key_dict,
@@ -133,52 +134,55 @@ class ProxyBaseLLMRequestProcessing:
             proxy_config=proxy_config,
         )
 
-        data["model"] = (
+        self.data["model"] = (
             general_settings.get("completion_model", None)  # server default
             or user_model  # model name passed via cli args
             or model  # for azure deployments
-            or data.get("model", None)  # default passed in http request
+            or self.data.get("model", None)  # default passed in http request
         )
 
         # override with user settings, these are params passed via cli
         if user_temperature:
-            data["temperature"] = user_temperature
+            self.data["temperature"] = user_temperature
         if user_request_timeout:
-            data["request_timeout"] = user_request_timeout
+            self.data["request_timeout"] = user_request_timeout
         if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
+            self.data["max_tokens"] = user_max_tokens
         if user_api_base:
-            data["api_base"] = user_api_base
+            self.data["api_base"] = user_api_base
 
         ### MODEL ALIAS MAPPING ###
         # check if model name in model alias map
         # get the actual model name
-        if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
-            data["model"] = litellm.model_alias_map[data["model"]]
+        if (
+            isinstance(self.data["model"], str)
+            and self.data["model"] in litellm.model_alias_map
+        ):
+            self.data["model"] = litellm.model_alias_map[self.data["model"]]
 
         ### CALL HOOKS ### - modify/reject incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        self.data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"
         )
 
         ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
         ## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
- data["litellm_call_id"] = request.headers.get( + self.data["litellm_call_id"] = request.headers.get( "x-litellm-call-id", str(uuid.uuid4()) ) - logging_obj, data = litellm.utils.function_setup( + logging_obj, self.data = litellm.utils.function_setup( original_function=route_type, rules_obj=litellm.utils.Rules(), start_time=datetime.now(), - **data, + **self.data, ) - data["litellm_logging_obj"] = logging_obj + self.data["litellm_logging_obj"] = logging_obj tasks = [] tasks.append( proxy_logging_obj.during_call_hook( - data=data, + data=self.data, user_api_key_dict=user_api_key_dict, call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type( route_type=route_type @@ -189,7 +193,7 @@ class ProxyBaseLLMRequestProcessing: ### ROUTE THE REQUEST ### # Do not change this - it should be a constant time fetch - ALWAYS llm_call = await route_request( - data=data, + data=self.data, route_type=route_type, llm_router=llm_router, user_model=user_model, @@ -217,14 +221,14 @@ class ProxyBaseLLMRequestProcessing: # Post Call Processing if llm_router is not None: - data["deployment"] = llm_router.get_deployment(model_id=model_id) + self.data["deployment"] = llm_router.get_deployment(model_id=model_id) asyncio.create_task( proxy_logging_obj.update_request_status( - litellm_call_id=data.get("litellm_call_id", ""), status="success" + litellm_call_id=self.data.get("litellm_call_id", ""), status="success" ) ) if ( - "stream" in data and data["stream"] is True + "stream" in self.data and self.data["stream"] is True ): # use generate_responses to stream responses custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( user_api_key_dict=user_api_key_dict, @@ -236,14 +240,14 @@ class ProxyBaseLLMRequestProcessing: response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), fastest_response_batch_completion=fastest_response_batch_completion, - request_data=data, + request_data=self.data, hidden_params=hidden_params, **additional_headers, ) selected_data_generator = select_data_generator( response=response, user_api_key_dict=user_api_key_dict, - request_data=data, + request_data=self.data, ) return StreamingResponse( selected_data_generator, @@ -253,7 +257,7 @@ class ProxyBaseLLMRequestProcessing: ### CALL HOOKS ### - modify outgoing data response = await proxy_logging_obj.post_call_success_hook( - data=data, user_api_key_dict=user_api_key_dict, response=response + data=self.data, user_api_key_dict=user_api_key_dict, response=response ) hidden_params = ( @@ -272,7 +276,7 @@ class ProxyBaseLLMRequestProcessing: response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), fastest_response_batch_completion=fastest_response_batch_completion, - request_data=data, + request_data=self.data, hidden_params=hidden_params, **additional_headers, ) @@ -281,10 +285,9 @@ class ProxyBaseLLMRequestProcessing: return response - @staticmethod async def _handle_llm_api_exception( + self, e: Exception, - data: dict, user_api_key_dict: UserAPIKeyAuth, proxy_logging_obj: ProxyLogging, version: Optional[str] = None, @@ -294,7 +297,9 @@ class ProxyBaseLLMRequestProcessing: f"litellm.proxy.proxy_server._handle_llm_api_exception(): Exception occured - {str(e)}" ) await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=self.data, ) litellm_debug_info = getattr(e, "litellm_debug_info", "") 
         verbose_proxy_logger.debug(
@@ -306,7 +311,7 @@ class ProxyBaseLLMRequestProcessing:
         timeout = getattr(
             e, "timeout", None
         )  # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
-        _litellm_logging_obj: Optional[LiteLLMLoggingObj] = data.get(
+        _litellm_logging_obj: Optional[LiteLLMLoggingObj] = self.data.get(
            "litellm_logging_obj", None
         )
         custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
@@ -317,7 +322,7 @@ class ProxyBaseLLMRequestProcessing:
             version=version,
             response_cost=0,
             model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
-            request_data=data,
+            request_data=self.data,
             timeout=timeout,
         )
         headers = getattr(e, "headers", {}) or {}
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d866fe1e75..6b5d62ca49 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3454,11 +3454,10 @@ async def chat_completion(  # noqa: PLR0915
     """
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    data = {}
+    data = await _read_request_body(request=request)
+    base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
     try:
-        data = await _read_request_body(request=request)
-        return await ProxyBaseLLMRequestProcessing.base_process_llm_request(
-            data=data,
+        return await base_llm_response_processor.base_process_llm_request(
             request=request,
             fastapi_response=fastapi_response,
             user_api_key_dict=user_api_key_dict,
@@ -3510,9 +3509,8 @@ async def chat_completion(  # noqa: PLR0915
         _chat_response.usage = _usage  # type: ignore
         return _chat_response
     except Exception as e:
-        raise await ProxyBaseLLMRequestProcessing._handle_llm_api_exception(
+        raise await base_llm_response_processor._handle_llm_api_exception(
             e=e,
-            data=data,
             user_api_key_dict=user_api_key_dict,
             proxy_logging_obj=proxy_logging_obj,
         )
diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py
index 31d3c2ca90..8649276b0e 100644
--- a/litellm/proxy/response_api_endpoints/endpoints.py
+++ b/litellm/proxy/response_api_endpoints/endpoints.py
@@ -50,11 +50,10 @@ async def responses_api(
         version,
     )
 
-    data = {}
+    data = await _read_request_body(request=request)
+    processor = ProxyBaseLLMRequestProcessing(data=data)
     try:
-        data = await _read_request_body(request=request)
-        return await ProxyBaseLLMRequestProcessing.base_process_llm_request(
-            data=data,
+        return await processor.base_process_llm_request(
             request=request,
             fastapi_response=fastapi_response,
             user_api_key_dict=user_api_key_dict,
@@ -73,9 +72,8 @@ async def responses_api(
             version=version,
         )
     except Exception as e:
-        raise await ProxyBaseLLMRequestProcessing._handle_llm_api_exception(
+        raise await processor._handle_llm_api_exception(
             e=e,
-            data=data,
             user_api_key_dict=user_api_key_dict,
             proxy_logging_obj=proxy_logging_obj,
             version=version,
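
Usage sketch (illustrative, not part of the diff): with this change the endpoint builds a ProxyBaseLLMRequestProcessing instance from the parsed request body, and because the instance carries self.data, the exception path no longer needs the request data passed in. The endpoint name below and the trimmed parameter plumbing are assumptions for brevity; the real call sites are chat_completion and responses_api above, which pass the full set of router/logging keyword arguments.

    # Hypothetical endpoint wiring; proxy_logging_obj etc. come from proxy_server globals.
    async def example_endpoint(request, fastapi_response, user_api_key_dict):
        data = await _read_request_body(request=request)      # parse the body once, up front
        processor = ProxyBaseLLMRequestProcessing(data=data)  # the instance now owns self.data
        try:
            return await processor.base_process_llm_request(
                request=request,
                fastapi_response=fastapi_response,
                user_api_key_dict=user_api_key_dict,
                # ...remaining router/logging keyword arguments as in the endpoints above...
            )
        except Exception as e:
            # request data is read from self.data, so no request_data argument is needed here
            raise await processor._handle_llm_api_exception(
                e=e,
                user_api_key_dict=user_api_key_dict,
                proxy_logging_obj=proxy_logging_obj,
            )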