Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

responses_api

Commit db1f48bbfb (parent e0252a9b49)
3 changed files with 46 additions and 45 deletions
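This commit turns ProxyBaseLLMRequestProcessing from a set of static methods that thread a `data` dict through every call into a small stateful helper: the request payload is stored once in `__init__`, and `base_process_llm_request` / `_handle_llm_api_exception` become instance methods that read `self.data`. The caller-side effect is visible in the `chat_completion` and `responses_api` hunks below; roughly (a sketch only, with `request`, `fastapi_response`, and `user_api_key_dict` assumed to come from the FastAPI route exactly as before):

# Before: static call, payload passed explicitly
data = await _read_request_body(request=request)
return await ProxyBaseLLMRequestProcessing.base_process_llm_request(
    data=data,
    request=request,
    fastapi_response=fastapi_response,
    user_api_key_dict=user_api_key_dict,
    # ...remaining router/logging arguments unchanged by this commit
)

# After: build the processor once, then call the instance method that uses self.data
data = await _read_request_body(request=request)
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
return await base_llm_response_processor.base_process_llm_request(
    request=request,
    fastapi_response=fastapi_response,
    user_api_key_dict=user_api_key_dict,
    # ...remaining router/logging arguments unchanged by this commit
)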
@@ -31,6 +31,8 @@ from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request


 class ProxyBaseLLMRequestProcessing:
+    def __init__(self, data: dict):
+        self.data = data

     @staticmethod
     def get_custom_headers(
@@ -97,9 +99,8 @@ class ProxyBaseLLMRequestProcessing:
             verbose_proxy_logger.error(f"Error setting custom headers: {e}")
         return {}

-    @staticmethod
     async def base_process_llm_request(
-        data: dict,
+        self,
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
@@ -121,11 +122,11 @@ class ProxyBaseLLMRequestProcessing:
         Common request processing logic for both chat completions and responses API endpoints
         """
         verbose_proxy_logger.debug(
-            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
+            "Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
         )

-        data = await add_litellm_data_to_request(
-            data=data,
+        self.data = await add_litellm_data_to_request(
+            data=self.data,
             request=request,
             general_settings=general_settings,
             user_api_key_dict=user_api_key_dict,
@@ -133,52 +134,55 @@ class ProxyBaseLLMRequestProcessing:
             proxy_config=proxy_config,
         )

-        data["model"] = (
+        self.data["model"] = (
             general_settings.get("completion_model", None)  # server default
             or user_model  # model name passed via cli args
             or model  # for azure deployments
-            or data.get("model", None)  # default passed in http request
+            or self.data.get("model", None)  # default passed in http request
         )

         # override with user settings, these are params passed via cli
         if user_temperature:
-            data["temperature"] = user_temperature
+            self.data["temperature"] = user_temperature
         if user_request_timeout:
-            data["request_timeout"] = user_request_timeout
+            self.data["request_timeout"] = user_request_timeout
         if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
+            self.data["max_tokens"] = user_max_tokens
         if user_api_base:
-            data["api_base"] = user_api_base
+            self.data["api_base"] = user_api_base

         ### MODEL ALIAS MAPPING ###
         # check if model name in model alias map
         # get the actual model name
-        if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
-            data["model"] = litellm.model_alias_map[data["model"]]
+        if (
+            isinstance(self.data["model"], str)
+            and self.data["model"] in litellm.model_alias_map
+        ):
+            self.data["model"] = litellm.model_alias_map[self.data["model"]]

         ### CALL HOOKS ### - modify/reject incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        self.data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"
         )

         ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
         ## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
-        data["litellm_call_id"] = request.headers.get(
+        self.data["litellm_call_id"] = request.headers.get(
             "x-litellm-call-id", str(uuid.uuid4())
         )
-        logging_obj, data = litellm.utils.function_setup(
+        logging_obj, self.data = litellm.utils.function_setup(
             original_function=route_type,
             rules_obj=litellm.utils.Rules(),
             start_time=datetime.now(),
-            **data,
+            **self.data,
         )

-        data["litellm_logging_obj"] = logging_obj
+        self.data["litellm_logging_obj"] = logging_obj

         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
-                data=data,
+                data=self.data,
                 user_api_key_dict=user_api_key_dict,
                 call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
                     route_type=route_type
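For reference, the alias lookup reshaped in the hunk above behaves like the minimal standalone sketch below; the alias and model names are made up for illustration, while `litellm.model_alias_map` is the same module-level mapping the code reads from.

import litellm

# Hypothetical alias -> underlying model mapping (illustrative values only)
litellm.model_alias_map = {"my-alias": "gpt-3.5-turbo"}

data = {"model": "my-alias"}
if (
    isinstance(data["model"], str)
    and data["model"] in litellm.model_alias_map
):
    data["model"] = litellm.model_alias_map[data["model"]]

print(data["model"])  # -> "gpt-3.5-turbo"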
@@ -189,7 +193,7 @@ class ProxyBaseLLMRequestProcessing:
         ### ROUTE THE REQUEST ###
         # Do not change this - it should be a constant time fetch - ALWAYS
         llm_call = await route_request(
-            data=data,
+            data=self.data,
             route_type=route_type,
             llm_router=llm_router,
             user_model=user_model,
@@ -217,14 +221,14 @@ class ProxyBaseLLMRequestProcessing:

         # Post Call Processing
         if llm_router is not None:
-            data["deployment"] = llm_router.get_deployment(model_id=model_id)
+            self.data["deployment"] = llm_router.get_deployment(model_id=model_id)
         asyncio.create_task(
             proxy_logging_obj.update_request_status(
-                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+                litellm_call_id=self.data.get("litellm_call_id", ""), status="success"
             )
         )
         if (
-            "stream" in data and data["stream"] is True
+            "stream" in self.data and self.data["stream"] is True
         ):  # use generate_responses to stream responses
             custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
                 user_api_key_dict=user_api_key_dict,
@@ -236,14 +240,14 @@ class ProxyBaseLLMRequestProcessing:
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
-                request_data=data,
+                request_data=self.data,
                 hidden_params=hidden_params,
                 **additional_headers,
             )
             selected_data_generator = select_data_generator(
                 response=response,
                 user_api_key_dict=user_api_key_dict,
-                request_data=data,
+                request_data=self.data,
             )
             return StreamingResponse(
                 selected_data_generator,
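The streaming branch above hands the selected generator to Starlette's StreamingResponse. A minimal, self-contained FastAPI sketch of that pattern (hypothetical generator and route, not LiteLLM code) looks like this:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_chunk_generator():
    # Stand-in for select_data_generator(...): yields already-formatted SSE chunks
    for chunk in ("one", "two", "three"):
        yield f"data: {chunk}\n\n"

@app.get("/stream")
async def stream():
    return StreamingResponse(fake_chunk_generator(), media_type="text/event-stream")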
@@ -253,7 +257,7 @@ class ProxyBaseLLMRequestProcessing:

         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
-            data=data, user_api_key_dict=user_api_key_dict, response=response
+            data=self.data, user_api_key_dict=user_api_key_dict, response=response
         )

         hidden_params = (
@@ -272,7 +276,7 @@ class ProxyBaseLLMRequestProcessing:
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
-                request_data=data,
+                request_data=self.data,
                 hidden_params=hidden_params,
                 **additional_headers,
             )
@@ -281,10 +285,9 @@ class ProxyBaseLLMRequestProcessing:

         return response

-    @staticmethod
     async def _handle_llm_api_exception(
+        self,
         e: Exception,
-        data: dict,
         user_api_key_dict: UserAPIKeyAuth,
         proxy_logging_obj: ProxyLogging,
         version: Optional[str] = None,
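Because the payload now lives on the instance, the endpoints no longer pass `data` back in when an exception occurs. The call-site shape after this change (a sketch mirroring the `chat_completion` and `responses_api` hunks further down, with the endpoint arguments assumed from the FastAPI route) is roughly:

processor = ProxyBaseLLMRequestProcessing(data=data)
try:
    return await processor.base_process_llm_request(
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        # ...remaining arguments unchanged by this commit
    )
except Exception as e:
    # data is no longer forwarded here; the handler reads self.data instead
    raise await processor._handle_llm_api_exception(
        e=e,
        user_api_key_dict=user_api_key_dict,
        proxy_logging_obj=proxy_logging_obj,
    )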
@@ -294,7 +297,9 @@ class ProxyBaseLLMRequestProcessing:
             f"litellm.proxy.proxy_server._handle_llm_api_exception(): Exception occured - {str(e)}"
         )
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=self.data,
         )
         litellm_debug_info = getattr(e, "litellm_debug_info", "")
         verbose_proxy_logger.debug(
@@ -306,7 +311,7 @@ class ProxyBaseLLMRequestProcessing:
         timeout = getattr(
             e, "timeout", None
         )  # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
-        _litellm_logging_obj: Optional[LiteLLMLoggingObj] = data.get(
+        _litellm_logging_obj: Optional[LiteLLMLoggingObj] = self.data.get(
             "litellm_logging_obj", None
         )
         custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
@@ -317,7 +322,7 @@ class ProxyBaseLLMRequestProcessing:
             version=version,
             response_cost=0,
             model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
-            request_data=data,
+            request_data=self.data,
             timeout=timeout,
         )
         headers = getattr(e, "headers", {}) or {}
@@ -3454,11 +3454,10 @@ async def chat_completion(  # noqa: PLR0915
     """
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    data = {}
+    data = await _read_request_body(request=request)
+    base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
     try:
-        data = await _read_request_body(request=request)
-        return await ProxyBaseLLMRequestProcessing.base_process_llm_request(
-            data=data,
+        return await base_llm_response_processor.base_process_llm_request(
             request=request,
             fastapi_response=fastapi_response,
             user_api_key_dict=user_api_key_dict,
@@ -3510,9 +3509,8 @@ async def chat_completion(  # noqa: PLR0915
             _chat_response.usage = _usage  # type: ignore
             return _chat_response
     except Exception as e:
-        raise await ProxyBaseLLMRequestProcessing._handle_llm_api_exception(
+        raise await base_llm_response_processor._handle_llm_api_exception(
             e=e,
-            data=data,
             user_api_key_dict=user_api_key_dict,
             proxy_logging_obj=proxy_logging_obj,
         )
@@ -50,11 +50,10 @@ async def responses_api(
         version,
     )

-    data = {}
+    data = await _read_request_body(request=request)
+    processor = ProxyBaseLLMRequestProcessing(data=data)
     try:
-        data = await _read_request_body(request=request)
-        return await ProxyBaseLLMRequestProcessing.base_process_llm_request(
-            data=data,
+        return await processor.base_process_llm_request(
             request=request,
             fastapi_response=fastapi_response,
             user_api_key_dict=user_api_key_dict,
@@ -73,9 +72,8 @@ async def responses_api(
             version=version,
         )
     except Exception as e:
-        raise await ProxyBaseLLMRequestProcessing._handle_llm_api_exception(
+        raise await processor._handle_llm_api_exception(
            e=e,
-            data=data,
             user_api_key_dict=user_api_key_dict,
             proxy_logging_obj=proxy_logging_obj,
             version=version,