From 86c9e05c10e308f9d3ecd750c67bf69d51613d2d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 20 Jul 2024 15:08:15 -0700 Subject: [PATCH 1/2] fix(openai.py): drop invalid params if `drop_params: true` for azure ai Fixes https://github.com/BerriAI/litellm/issues/4800 --- litellm/llms/openai.py | 110 +++++++++++++++++--------- litellm/proxy/_new_secret_config.yaml | 11 ++- 2 files changed, 78 insertions(+), 43 deletions(-) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 0a40ab3fe..e831ad140 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -984,47 +984,79 @@ class OpenAIChatCompletion(BaseLLM): headers=None, ): response = None - try: - openai_aclient = self._get_openai_client( - is_async=True, - api_key=api_key, - api_base=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - client=client, - ) + for _ in range( + 2 + ): # if call fails due to alternating messages, retry with reformatted message + try: + openai_aclient = self._get_openai_client( + is_async=True, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) - ## LOGGING - logging_obj.pre_call( - input=data["messages"], - api_key=openai_aclient.api_key, - additional_args={ - "headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, - "api_base": openai_aclient._base_url._uri_reference, - "acompletion": True, - "complete_input_dict": data, - }, - ) + ## LOGGING + logging_obj.pre_call( + input=data["messages"], + api_key=openai_aclient.api_key, + additional_args={ + "headers": { + "Authorization": f"Bearer {openai_aclient.api_key}" + }, + "api_base": openai_aclient._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) - headers, response = await self.make_openai_chat_completion_request( - openai_aclient=openai_aclient, data=data, timeout=timeout - ) - stringified_response = response.model_dump() - logging_obj.post_call( - input=data["messages"], - api_key=api_key, - original_response=stringified_response, - additional_args={"complete_input_dict": data}, - ) - logging_obj.model_call_details["response_headers"] = headers - return convert_to_model_response_object( - response_object=stringified_response, - model_response_object=model_response, - hidden_params={"headers": headers}, - ) - except Exception as e: - raise e + headers, response = await self.make_openai_chat_completion_request( + openai_aclient=openai_aclient, data=data, timeout=timeout + ) + stringified_response = response.model_dump() + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + original_response=stringified_response, + additional_args={"complete_input_dict": data}, + ) + logging_obj.model_call_details["response_headers"] = headers + return convert_to_model_response_object( + response_object=stringified_response, + model_response_object=model_response, + hidden_params={"headers": headers}, + ) + except openai.UnprocessableEntityError as e: + ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 + if litellm.drop_params is True: + if e.body is not None and e.body.get("detail"): # type: ignore + detail = e.body.get("detail") # type: ignore + invalid_params: List[str] = [] + if ( + isinstance(detail, List) + and len(detail) > 0 + and isinstance(detail[0], dict) + ): + for error_dict in detail: + if ( + error_dict.get("loc") + and isinstance(error_dict.get("loc"), list) + and 
len(error_dict.get("loc")) == 2 + ): + invalid_params.append(error_dict["loc"][1]) + + new_data = {} + for k, v in data.items(): + if k not in invalid_params: + new_data[k] = v + data = new_data + else: + raise e + # e.message + except Exception as e: + raise e def streaming( self, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index f62078da1..1a1258a41 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,6 +1,9 @@ model_list: - - model_name: azure-chatgpt + - model_name: azure-mistral litellm_params: - model: azure/chatgpt-v-2 - api_key: os.environ/AZURE_API_KEY - api_base: os.environ/AZURE_API_BASE \ No newline at end of file + model: azure_ai/mistral + api_key: os.environ/AZURE_AI_MISTRAL_API_KEY + api_base: os.environ/AZURE_AI_MISTRAL_API_BASE + +litellm_settings: + drop_params: true \ No newline at end of file From a27454b8e3112200f6c10520ab5a6e87f153a307 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 20 Jul 2024 15:23:42 -0700 Subject: [PATCH 2/2] fix(openai.py): support completion, streaming, async_streaming --- litellm/llms/openai.py | 157 +++++++++++++++++++++---------- litellm/main.py | 1 + litellm/tests/test_completion.py | 30 ++++++ litellm/tests/test_streaming.py | 34 +++++++ 4 files changed, 174 insertions(+), 48 deletions(-) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index e831ad140..2606c8f96 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -803,6 +803,7 @@ class OpenAIChatCompletion(BaseLLM): client=None, organization: Optional[str] = None, custom_llm_provider: Optional[str] = None, + drop_params: Optional[bool] = None, ): super().completion() exception_mapping_worked = False @@ -858,6 +859,7 @@ class OpenAIChatCompletion(BaseLLM): client=client, max_retries=max_retries, organization=organization, + drop_params=drop_params, ) else: return self.acompletion( @@ -871,6 +873,7 @@ class OpenAIChatCompletion(BaseLLM): client=client, max_retries=max_retries, organization=organization, + drop_params=drop_params, ) elif optional_params.get("stream", False): return self.streaming( @@ -925,6 +928,33 @@ class OpenAIChatCompletion(BaseLLM): response_object=stringified_response, model_response_object=model_response, ) + except openai.UnprocessableEntityError as e: + ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 + if litellm.drop_params is True or drop_params is True: + if e.body is not None and e.body.get("detail"): # type: ignore + detail = e.body.get("detail") # type: ignore + invalid_params: List[str] = [] + if ( + isinstance(detail, List) + and len(detail) > 0 + and isinstance(detail[0], dict) + ): + for error_dict in detail: + if ( + error_dict.get("loc") + and isinstance(error_dict.get("loc"), list) + and len(error_dict.get("loc")) == 2 + ): + invalid_params.append(error_dict["loc"][1]) + + new_data = {} + for k, v in optional_params.items(): + if k not in invalid_params: + new_data[k] = v + optional_params = new_data + else: + raise e + # e.message except Exception as e: if print_verbose is not None: print_verbose(f"openai.py: Received openai error - {str(e)}") @@ -982,6 +1012,7 @@ class OpenAIChatCompletion(BaseLLM): client=None, max_retries=None, headers=None, + drop_params: Optional[bool] = None, ): response = None for _ in range( @@ -1030,7 +1061,7 @@ class OpenAIChatCompletion(BaseLLM): ) except openai.UnprocessableEntityError as e: ## check if body contains unprocessable 
params - related issue https://github.com/BerriAI/litellm/issues/4800 - if litellm.drop_params is True: + if litellm.drop_params is True or drop_params is True: if e.body is not None and e.body.get("detail"): # type: ignore detail = e.body.get("detail") # type: ignore invalid_params: List[str] = [] @@ -1113,57 +1144,87 @@ class OpenAIChatCompletion(BaseLLM): client=None, max_retries=None, headers=None, + drop_params: Optional[bool] = None, ): response = None - try: - openai_aclient = self._get_openai_client( - is_async=True, - api_key=api_key, - api_base=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - client=client, - ) - ## LOGGING - logging_obj.pre_call( - input=data["messages"], - api_key=api_key, - additional_args={ - "headers": headers, - "api_base": api_base, - "acompletion": True, - "complete_input_dict": data, - }, - ) - - headers, response = await self.make_openai_chat_completion_request( - openai_aclient=openai_aclient, data=data, timeout=timeout - ) - logging_obj.model_call_details["response_headers"] = headers - streamwrapper = CustomStreamWrapper( - completion_stream=response, - model=model, - custom_llm_provider="openai", - logging_obj=logging_obj, - stream_options=data.get("stream_options", None), - ) - return streamwrapper - except ( - Exception - ) as e: # need to exception handle here. async exceptions don't get caught in sync functions. - if response is not None and hasattr(response, "text"): - raise OpenAIError( - status_code=500, - message=f"{str(e)}\n\nOriginal Response: {response.text}", + for _ in range(2): + try: + openai_aclient = self._get_openai_client( + is_async=True, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, ) - else: - if type(e).__name__ == "ReadTimeout": - raise OpenAIError(status_code=408, message=f"{type(e).__name__}") - elif hasattr(e, "status_code"): - raise OpenAIError(status_code=e.status_code, message=str(e)) + ## LOGGING + logging_obj.pre_call( + input=data["messages"], + api_key=api_key, + additional_args={ + "headers": headers, + "api_base": api_base, + "acompletion": True, + "complete_input_dict": data, + }, + ) + + headers, response = await self.make_openai_chat_completion_request( + openai_aclient=openai_aclient, data=data, timeout=timeout + ) + logging_obj.model_call_details["response_headers"] = headers + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="openai", + logging_obj=logging_obj, + stream_options=data.get("stream_options", None), + ) + return streamwrapper + except openai.UnprocessableEntityError as e: + ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 + if litellm.drop_params is True or drop_params is True: + if e.body is not None and e.body.get("detail"): # type: ignore + detail = e.body.get("detail") # type: ignore + invalid_params: List[str] = [] + if ( + isinstance(detail, List) + and len(detail) > 0 + and isinstance(detail[0], dict) + ): + for error_dict in detail: + if ( + error_dict.get("loc") + and isinstance(error_dict.get("loc"), list) + and len(error_dict.get("loc")) == 2 + ): + invalid_params.append(error_dict["loc"][1]) + + new_data = {} + for k, v in data.items(): + if k not in invalid_params: + new_data[k] = v + data = new_data else: - raise OpenAIError(status_code=500, message=f"{str(e)}") + raise e + except ( + Exception + ) as e: # need to exception handle here. 
async exceptions don't get caught in sync functions. + if response is not None and hasattr(response, "text"): + raise OpenAIError( + status_code=500, + message=f"{str(e)}\n\nOriginal Response: {response.text}", + ) + else: + if type(e).__name__ == "ReadTimeout": + raise OpenAIError( + status_code=408, message=f"{type(e).__name__}" + ) + elif hasattr(e, "status_code"): + raise OpenAIError(status_code=e.status_code, message=str(e)) + else: + raise OpenAIError(status_code=500, message=f"{str(e)}") # Embedding async def make_openai_embedding_request( diff --git a/litellm/main.py b/litellm/main.py index 7e0e534ad..8cb52d945 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1176,6 +1176,7 @@ def completion( client=client, # pass AsyncOpenAI, OpenAI client organization=organization, custom_llm_provider=custom_llm_provider, + drop_params=non_default_params.get("drop_params"), ) except Exception as e: ## LOGGING - log the original exception returned diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 34eebb712..7eda96cb9 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -142,6 +142,36 @@ def test_completion_azure_ai_command_r(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_azure_ai_mistral_invalid_params(sync_mode): + try: + import os + + litellm.set_verbose = True + + os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "") + os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "") + + data = { + "model": "azure_ai/mistral", + "messages": [{"role": "user", "content": "What is the meaning of life?"}], + "frequency_penalty": 0.1, + "presence_penalty": 0.1, + "drop_params": True, + } + if sync_mode: + response: litellm.ModelResponse = completion(**data) # type: ignore + else: + response: litellm.ModelResponse = await litellm.acompletion(**data) # type: ignore + + assert "azure_ai" in response.model + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_completion_azure_command_r(): try: litellm.set_verbose = True diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index d07aa681d..64c2eb4ab 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -2881,6 +2881,40 @@ def test_azure_streaming_and_function_calling(): raise e +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_azure_ai_mistral_invalid_params(sync_mode): + try: + import os + + litellm.set_verbose = True + + os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "") + os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "") + + data = { + "model": "azure_ai/mistral", + "messages": [{"role": "user", "content": "What is the meaning of life?"}], + "frequency_penalty": 0.1, + "presence_penalty": 0.1, + "drop_params": True, + "stream": True, + } + if sync_mode: + response: litellm.ModelResponse = completion(**data) # type: ignore + for chunk in response: + print(chunk) + else: + response: litellm.ModelResponse = await litellm.acompletion(**data) # type: ignore + + async for chunk in response: + print(chunk) + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + @pytest.mark.asyncio async def test_azure_astreaming_and_function_calling(): import uuid
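
The hunks above repeat one pattern in completion, acompletion, and the streaming paths: catch openai.UnprocessableEntityError, read e.body["detail"], take the second element of each two-element "loc" entry as the name of the rejected parameter, drop those keys from the request payload, and retry once (the for _ in range(2) loop). A minimal standalone sketch of that pattern follows, assuming a 422 body shaped like the Azure AI validation errors referenced in issue 4800; drop_invalid_params and the sample error body are illustrative only and are not part of the diff.

from typing import Any, Dict, List


def drop_invalid_params(data: Dict[str, Any], error_body: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `data` without the params a 422 response flagged as invalid.

    Each entry in `detail` is expected to carry a two-element `loc` list such as
    ["body", "frequency_penalty"], where the second element names the offending
    request parameter (the same assumption the patch makes).
    """
    detail = error_body.get("detail")
    invalid_params: List[str] = []
    if isinstance(detail, list) and detail and isinstance(detail[0], dict):
        for error_dict in detail:
            loc = error_dict.get("loc")
            if isinstance(loc, list) and len(loc) == 2:
                invalid_params.append(loc[1])
    # Keep every key the provider did not complain about.
    return {k: v for k, v in data.items() if k not in invalid_params}


if __name__ == "__main__":
    # Hypothetical 422 body from a provider that rejects OpenAI-only params.
    body = {
        "detail": [
            {"loc": ["body", "frequency_penalty"], "msg": "extra fields not permitted"},
            {"loc": ["body", "presence_penalty"], "msg": "extra fields not permitted"},
        ]
    }
    request = {
        "messages": [{"role": "user", "content": "What is the meaning of life?"}],
        "frequency_penalty": 0.1,
        "presence_penalty": 0.1,
    }
    # The filtered payload is what the second loop iteration would resend.
    print(drop_invalid_params(request, body))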