"""
Common helpers / utils across all OpenAI endpoints
"""

import json
from typing import TYPE_CHECKING, Any, Dict, List

if TYPE_CHECKING:
    # Imported only for the type annotation below; keeps this module
    # importable without eagerly loading the openai SDK at runtime.
    import openai


####### Error Handling Utils for OpenAI API #######################
###################################################################
def drop_params_from_unprocessable_entity_error(
    e: "openai.UnprocessableEntityError", data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Read an OpenAI UnprocessableEntityError and drop from `data` the
    parameters that the error message flags as invalid.

    Handles both error-body shapes:
    - ``{"detail": [...]}`` — raw FastAPI-style validation errors
      (the shape handled by the inline code this helper replaces,
      see https://github.com/BerriAI/litellm/issues/4800)
    - ``{"message": "<json-encoded {'detail': [...]}>"}`` — the same
      errors wrapped in a string "message" field by a proxy

    Args:
        e (openai.UnprocessableEntityError): The exception raised by the
            OpenAI client; only its ``.body`` attribute is inspected.
        data (Dict[str, Any]): The original request parameters.

    Returns:
        Dict[str, Any]: A new dictionary with the invalid parameters
        removed. ``data`` itself is never mutated.
    """
    invalid_params: List[str] = []
    if e.body is not None and isinstance(e.body, dict):
        # Prefer a top-level "detail" list (raw validation error shape).
        detail = e.body.get("detail")
        if detail is None and e.body.get("message"):
            # Fall back to a proxy-wrapped "message" field, which may be a
            # JSON string containing the same {"detail": [...]} structure.
            message = e.body["message"]
            if isinstance(message, str):
                try:
                    message = json.loads(message)
                except json.JSONDecodeError:
                    message = {"detail": message}
            if isinstance(message, dict):
                detail = message.get("detail")
        if isinstance(detail, list) and len(detail) > 0 and isinstance(detail[0], dict):
            for error_dict in detail:
                loc = error_dict.get("loc")
                # loc is typically ["body", "<param_name>"]; the second
                # element names the offending request parameter.
                if loc and isinstance(loc, list) and len(loc) == 2:
                    invalid_params.append(loc[1])

    return {k: v for k, v in data.items() if k not in invalid_params}
custom_prompt, prompt_factory +from .common_utils import drop_params_from_unprocessable_entity_error class OpenAIError(Exception): @@ -831,27 +832,9 @@ class OpenAIChatCompletion(BaseLLM): except openai.UnprocessableEntityError as e: ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 if litellm.drop_params is True or drop_params is True: - invalid_params: List[str] = [] - if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore - detail = e.body.get("detail") # type: ignore - if ( - isinstance(detail, List) - and len(detail) > 0 - and isinstance(detail[0], dict) - ): - for error_dict in detail: - if ( - error_dict.get("loc") - and isinstance(error_dict.get("loc"), list) - and len(error_dict.get("loc")) == 2 - ): - invalid_params.append(error_dict["loc"][1]) - - new_data = {} - for k, v in optional_params.items(): - if k not in invalid_params: - new_data[k] = v - optional_params = new_data + optional_params = drop_params_from_unprocessable_entity_error( + e, optional_params + ) else: raise e # e.message @@ -967,27 +950,7 @@ class OpenAIChatCompletion(BaseLLM): except openai.UnprocessableEntityError as e: ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 if litellm.drop_params is True or drop_params is True: - invalid_params: List[str] = [] - if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore - detail = e.body.get("detail") # type: ignore - if ( - isinstance(detail, List) - and len(detail) > 0 - and isinstance(detail[0], dict) - ): - for error_dict in detail: - if ( - error_dict.get("loc") - and isinstance(error_dict.get("loc"), list) - and len(error_dict.get("loc")) == 2 - ): - invalid_params.append(error_dict["loc"][1]) - - new_data = {} - for k, v in data.items(): - if k not in invalid_params: - new_data[k] = v - data = new_data + data = 
drop_params_from_unprocessable_entity_error(e, data) else: raise e # e.message