Merge pull request #4806 from BerriAI/litellm_drop_invalid_params

fix(openai.py): drop invalid params if `drop_params: true` for azure ai
Krish Dholakia 2024-07-20 17:45:46 -07:00 committed by GitHub
commit fee4f3385d
5 changed files with 251 additions and 90 deletions
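
For context, the behaviour this change enables (mirrored by the new tests further down): a caller can pass parameters an Azure AI endpoint rejects and have litellm drop them and retry instead of surfacing the 422. A minimal usage sketch, assuming `AZURE_AI_API_KEY` / `AZURE_AI_API_BASE` point at an Azure AI Mistral deployment and that `frequency_penalty` / `presence_penalty` are the rejected params (the scenario from issue #4800):

import litellm

# Assumes AZURE_AI_API_KEY / AZURE_AI_API_BASE are set for an Azure AI Mistral deployment.
response = litellm.completion(
    model="azure_ai/mistral",
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    frequency_penalty=0.1,  # rejected by this endpoint (issue #4800)
    presence_penalty=0.1,   # rejected by this endpoint (issue #4800)
    drop_params=True,       # per-call flag; setting litellm.drop_params = True also works
)
print(response.model)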


@@ -803,6 +803,7 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         organization: Optional[str] = None,
         custom_llm_provider: Optional[str] = None,
+        drop_params: Optional[bool] = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -858,6 +859,7 @@ class OpenAIChatCompletion(BaseLLM):
                     client=client,
                     max_retries=max_retries,
                     organization=organization,
+                    drop_params=drop_params,
                 )
             else:
                 return self.acompletion(
@@ -871,6 +873,7 @@ class OpenAIChatCompletion(BaseLLM):
                     client=client,
                     max_retries=max_retries,
                     organization=organization,
+                    drop_params=drop_params,
                 )
         elif optional_params.get("stream", False):
             return self.streaming(
@@ -925,6 +928,33 @@ class OpenAIChatCompletion(BaseLLM):
                         response_object=stringified_response,
                         model_response_object=model_response,
                     )
+            except openai.UnprocessableEntityError as e:
+                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                if litellm.drop_params is True or drop_params is True:
+                    if e.body is not None and e.body.get("detail"):  # type: ignore
+                        detail = e.body.get("detail")  # type: ignore
+                        invalid_params: List[str] = []
+                        if (
+                            isinstance(detail, List)
+                            and len(detail) > 0
+                            and isinstance(detail[0], dict)
+                        ):
+                            for error_dict in detail:
+                                if (
+                                    error_dict.get("loc")
+                                    and isinstance(error_dict.get("loc"), list)
+                                    and len(error_dict.get("loc")) == 2
+                                ):
+                                    invalid_params.append(error_dict["loc"][1])
+                        new_data = {}
+                        for k, v in optional_params.items():
+                            if k not in invalid_params:
+                                new_data[k] = v
+                        optional_params = new_data
+                else:
+                    raise e
+                # e.message
             except Exception as e:
                 if print_verbose is not None:
                     print_verbose(f"openai.py: Received openai error - {str(e)}")
@@ -982,49 +1012,82 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         max_retries=None,
         headers=None,
+        drop_params: Optional[bool] = None,
     ):
         response = None
-        try:
-            openai_aclient = self._get_openai_client(
-                is_async=True,
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-                client=client,
-            )
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["messages"],
-                api_key=openai_aclient.api_key,
-                additional_args={
-                    "headers": {"Authorization": f"Bearer {openai_aclient.api_key}"},
-                    "api_base": openai_aclient._base_url._uri_reference,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
-            )
-            headers, response = await self.make_openai_chat_completion_request(
-                openai_aclient=openai_aclient, data=data, timeout=timeout
-            )
-            stringified_response = response.model_dump()
-            logging_obj.post_call(
-                input=data["messages"],
-                api_key=api_key,
-                original_response=stringified_response,
-                additional_args={"complete_input_dict": data},
-            )
-            logging_obj.model_call_details["response_headers"] = headers
-            return convert_to_model_response_object(
-                response_object=stringified_response,
-                model_response_object=model_response,
-                hidden_params={"headers": headers},
-            )
-        except Exception as e:
-            raise e
+        for _ in range(
+            2
+        ):  # if call fails due to alternating messages, retry with reformatted message
+            try:
+                openai_aclient = self._get_openai_client(
+                    is_async=True,
+                    api_key=api_key,
+                    api_base=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    client=client,
+                )
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=data["messages"],
+                    api_key=openai_aclient.api_key,
+                    additional_args={
+                        "headers": {
+                            "Authorization": f"Bearer {openai_aclient.api_key}"
+                        },
+                        "api_base": openai_aclient._base_url._uri_reference,
+                        "acompletion": True,
+                        "complete_input_dict": data,
+                    },
+                )
+                headers, response = await self.make_openai_chat_completion_request(
+                    openai_aclient=openai_aclient, data=data, timeout=timeout
+                )
+                stringified_response = response.model_dump()
+                logging_obj.post_call(
+                    input=data["messages"],
+                    api_key=api_key,
+                    original_response=stringified_response,
+                    additional_args={"complete_input_dict": data},
+                )
+                logging_obj.model_call_details["response_headers"] = headers
+                return convert_to_model_response_object(
+                    response_object=stringified_response,
+                    model_response_object=model_response,
+                    hidden_params={"headers": headers},
+                )
+            except openai.UnprocessableEntityError as e:
+                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                if litellm.drop_params is True or drop_params is True:
+                    if e.body is not None and e.body.get("detail"):  # type: ignore
+                        detail = e.body.get("detail")  # type: ignore
+                        invalid_params: List[str] = []
+                        if (
+                            isinstance(detail, List)
+                            and len(detail) > 0
+                            and isinstance(detail[0], dict)
+                        ):
+                            for error_dict in detail:
+                                if (
+                                    error_dict.get("loc")
+                                    and isinstance(error_dict.get("loc"), list)
+                                    and len(error_dict.get("loc")) == 2
+                                ):
+                                    invalid_params.append(error_dict["loc"][1])
+                        new_data = {}
+                        for k, v in data.items():
+                            if k not in invalid_params:
+                                new_data[k] = v
+                        data = new_data
+                else:
+                    raise e
+                # e.message
+            except Exception as e:
+                raise e
 
     def streaming(
         self,
@@ -1081,57 +1144,87 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         max_retries=None,
         headers=None,
+        drop_params: Optional[bool] = None,
     ):
         response = None
-        try:
-            openai_aclient = self._get_openai_client(
-                is_async=True,
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-                client=client,
-            )
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["messages"],
-                api_key=api_key,
-                additional_args={
-                    "headers": headers,
-                    "api_base": api_base,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
-            )
-            headers, response = await self.make_openai_chat_completion_request(
-                openai_aclient=openai_aclient, data=data, timeout=timeout
-            )
-            logging_obj.model_call_details["response_headers"] = headers
-            streamwrapper = CustomStreamWrapper(
-                completion_stream=response,
-                model=model,
-                custom_llm_provider="openai",
-                logging_obj=logging_obj,
-                stream_options=data.get("stream_options", None),
-            )
-            return streamwrapper
-        except (
-            Exception
-        ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
-            if response is not None and hasattr(response, "text"):
-                raise OpenAIError(
-                    status_code=500,
-                    message=f"{str(e)}\n\nOriginal Response: {response.text}",
-                )
-            else:
-                if type(e).__name__ == "ReadTimeout":
-                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
-                elif hasattr(e, "status_code"):
-                    raise OpenAIError(status_code=e.status_code, message=str(e))
-                else:
-                    raise OpenAIError(status_code=500, message=f"{str(e)}")
+        for _ in range(2):
+            try:
+                openai_aclient = self._get_openai_client(
+                    is_async=True,
+                    api_key=api_key,
+                    api_base=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    client=client,
+                )
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=data["messages"],
+                    api_key=api_key,
+                    additional_args={
+                        "headers": headers,
+                        "api_base": api_base,
+                        "acompletion": True,
+                        "complete_input_dict": data,
+                    },
+                )
+                headers, response = await self.make_openai_chat_completion_request(
+                    openai_aclient=openai_aclient, data=data, timeout=timeout
+                )
+                logging_obj.model_call_details["response_headers"] = headers
+                streamwrapper = CustomStreamWrapper(
+                    completion_stream=response,
+                    model=model,
+                    custom_llm_provider="openai",
+                    logging_obj=logging_obj,
+                    stream_options=data.get("stream_options", None),
+                )
+                return streamwrapper
+            except openai.UnprocessableEntityError as e:
+                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                if litellm.drop_params is True or drop_params is True:
+                    if e.body is not None and e.body.get("detail"):  # type: ignore
+                        detail = e.body.get("detail")  # type: ignore
+                        invalid_params: List[str] = []
+                        if (
+                            isinstance(detail, List)
+                            and len(detail) > 0
+                            and isinstance(detail[0], dict)
+                        ):
+                            for error_dict in detail:
+                                if (
+                                    error_dict.get("loc")
+                                    and isinstance(error_dict.get("loc"), list)
+                                    and len(error_dict.get("loc")) == 2
+                                ):
+                                    invalid_params.append(error_dict["loc"][1])
+                        new_data = {}
+                        for k, v in data.items():
+                            if k not in invalid_params:
+                                new_data[k] = v
+                        data = new_data
+                else:
+                    raise e
+            except (
+                Exception
+            ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
+                if response is not None and hasattr(response, "text"):
+                    raise OpenAIError(
+                        status_code=500,
+                        message=f"{str(e)}\n\nOriginal Response: {response.text}",
+                    )
+                else:
+                    if type(e).__name__ == "ReadTimeout":
+                        raise OpenAIError(
+                            status_code=408, message=f"{type(e).__name__}"
+                        )
+                    elif hasattr(e, "status_code"):
+                        raise OpenAIError(status_code=e.status_code, message=str(e))
+                    else:
+                        raise OpenAIError(status_code=500, message=f"{str(e)}")
 
     # Embedding
     async def make_openai_embedding_request(
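
Taken together, the paths above follow the same pattern: try the call, and on a 422 with param-dropping enabled prune the flagged keys and retry once; anything else re-raises. A self-contained toy illustration of that control flow (`Fake422` and `fake_send` are stand-ins invented for this sketch, not litellm code):

# Toy model of the retry-and-drop loop; not the repository's implementation.
class Fake422(Exception):
    def __init__(self, body):
        super().__init__("Unprocessable Entity")
        self.body = body


def fake_send(data):
    # Pretend the upstream rejects frequency_penalty with a FastAPI-style 422 body.
    if "frequency_penalty" in data:
        raise Fake422({"detail": [{"loc": ["body", "frequency_penalty"], "msg": "extra fields not permitted"}]})
    return {"ok": True, "sent": sorted(data)}


def call_with_drop(data, drop_params):
    for _ in range(2):  # at most one retry after pruning the rejected keys
        try:
            return fake_send(data)
        except Fake422 as e:
            if not drop_params:
                raise
            bad = [d["loc"][1] for d in e.body.get("detail", []) if len(d.get("loc", [])) == 2]
            data = {k: v for k, v in data.items() if k not in bad}


print(call_with_drop({"messages": [], "frequency_penalty": 0.1}, drop_params=True))
# -> {'ok': True, 'sent': ['messages']}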


@@ -1176,6 +1176,7 @@ def completion(
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
                 custom_llm_provider=custom_llm_provider,
+                drop_params=non_default_params.get("drop_params"),
             )
         except Exception as e:
             ## LOGGING - log the original exception returned


@@ -1,6 +1,9 @@
 model_list:
-  - model_name: azure-chatgpt
+  - model_name: azure-mistral
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+      model: azure_ai/mistral
+      api_key: os.environ/AZURE_AI_MISTRAL_API_KEY
+      api_base: os.environ/AZURE_AI_MISTRAL_API_BASE
+
+litellm_settings:
+  drop_params: true
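
With `drop_params: true` in `litellm_settings`, the proxy applies the same pruning for any OpenAI-compatible client. A hedged sketch of calling it (the proxy is assumed to be started with `litellm --config <this file>` and listening on its default port; the key and URL below are placeholders):

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="azure-mistral",  # model_name from the config above
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    frequency_penalty=0.1,  # rejected upstream; dropped by the proxy before retrying
    presence_penalty=0.1,
)
print(response.choices[0].message.content)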


@@ -142,6 +142,36 @@ def test_completion_azure_ai_command_r():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "")
+
+        data = {
+            "model": "azure_ai/mistral",
+            "messages": [{"role": "user", "content": "What is the meaning of life?"}],
+            "frequency_penalty": 0.1,
+            "presence_penalty": 0.1,
+            "drop_params": True,
+        }
+        if sync_mode:
+            response: litellm.ModelResponse = completion(**data)  # type: ignore
+        else:
+            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+
+        assert "azure_ai" in response.model
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_azure_command_r():
     try:
         litellm.set_verbose = True


@@ -2881,6 +2881,40 @@ def test_azure_streaming_and_function_calling():
         raise e
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "")
+
+        data = {
+            "model": "azure_ai/mistral",
+            "messages": [{"role": "user", "content": "What is the meaning of life?"}],
+            "frequency_penalty": 0.1,
+            "presence_penalty": 0.1,
+            "drop_params": True,
+            "stream": True,
+        }
+        if sync_mode:
+            response: litellm.ModelResponse = completion(**data)  # type: ignore
+            for chunk in response:
+                print(chunk)
+        else:
+            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+            async for chunk in response:
+                print(chunk)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 async def test_azure_astreaming_and_function_calling():
     import uuid