Merge pull request #4806 from BerriAI/litellm_drop_invalid_params
fix(openai.py): drop invalid params if `drop_params: true` for azure ai
commit fee4f3385d

5 changed files with 251 additions and 90 deletions
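For context, a minimal sketch of the behavior this change targets (the penalty params mirror the PR's tests; env var names and values are illustrative): Azure AI's Mistral endpoint rejects OpenAI params it does not support with a 422 Unprocessable Entity error, and with drop_params enabled litellm now strips the offending params and retries instead of surfacing the error.

    # Hedged sketch: assumes an Azure AI Mistral deployment configured via
    # AZURE_AI_API_BASE / AZURE_AI_API_KEY.
    import litellm

    response = litellm.completion(
        model="azure_ai/mistral",
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        frequency_penalty=0.1,  # may be rejected by the endpoint with a 422
        presence_penalty=0.1,
        drop_params=True,       # with this PR, rejected params are dropped and the call retried
    )
    print(response.model)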
@@ -803,6 +803,7 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         organization: Optional[str] = None,
         custom_llm_provider: Optional[str] = None,
+        drop_params: Optional[bool] = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -858,6 +859,7 @@ class OpenAIChatCompletion(BaseLLM):
                         client=client,
                         max_retries=max_retries,
                         organization=organization,
+                        drop_params=drop_params,
                     )
                 else:
                     return self.acompletion(
@@ -871,6 +873,7 @@ class OpenAIChatCompletion(BaseLLM):
                         client=client,
                         max_retries=max_retries,
                         organization=organization,
+                        drop_params=drop_params,
                     )
             elif optional_params.get("stream", False):
                 return self.streaming(
@@ -925,6 +928,33 @@ class OpenAIChatCompletion(BaseLLM):
                         response_object=stringified_response,
                         model_response_object=model_response,
                     )
+                except openai.UnprocessableEntityError as e:
+                    ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                    if litellm.drop_params is True or drop_params is True:
+                        if e.body is not None and e.body.get("detail"):  # type: ignore
+                            detail = e.body.get("detail")  # type: ignore
+                            invalid_params: List[str] = []
+                            if (
+                                isinstance(detail, List)
+                                and len(detail) > 0
+                                and isinstance(detail[0], dict)
+                            ):
+                                for error_dict in detail:
+                                    if (
+                                        error_dict.get("loc")
+                                        and isinstance(error_dict.get("loc"), list)
+                                        and len(error_dict.get("loc")) == 2
+                                    ):
+                                        invalid_params.append(error_dict["loc"][1])
+
+                        new_data = {}
+                        for k, v in optional_params.items():
+                            if k not in invalid_params:
+                                new_data[k] = v
+                        optional_params = new_data
+                    else:
+                        raise e
+                    # e.message
                 except Exception as e:
                     if print_verbose is not None:
                         print_verbose(f"openai.py: Received openai error - {str(e)}")
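To make the recovery path above concrete, here is a small self-contained sketch of the same extraction logic run against a hypothetical 422 `detail` payload; the payload values are invented for illustration, the real shape comes from the endpoint (see issue #4800).

    # Hypothetical 422 body, shaped like the pydantic-style validation errors the
    # handler above expects: each entry's "loc" is ["body", "<param_name>"].
    detail = [
        {"loc": ["body", "frequency_penalty"], "msg": "extra fields not permitted"},
        {"loc": ["body", "presence_penalty"], "msg": "extra fields not permitted"},
    ]

    invalid_params = []
    for error_dict in detail:
        if (
            error_dict.get("loc")
            and isinstance(error_dict.get("loc"), list)
            and len(error_dict.get("loc")) == 2
        ):
            # loc[1] is the name of the parameter the endpoint rejected
            invalid_params.append(error_dict["loc"][1])

    optional_params = {"temperature": 0.7, "frequency_penalty": 0.1, "presence_penalty": 0.1}
    # rebuild the kwargs without the rejected params; the retry loop then re-issues the call
    optional_params = {k: v for k, v in optional_params.items() if k not in invalid_params}
    print(optional_params)  # {'temperature': 0.7}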
@@ -982,8 +1012,12 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         max_retries=None,
         headers=None,
+        drop_params: Optional[bool] = None,
     ):
         response = None
         for _ in range(
             2
         ):  # if call fails due to alternating messages, retry with reformatted message
             try:
                 openai_aclient = self._get_openai_client(
                     is_async=True,
@@ -1000,7 +1034,9 @@ class OpenAIChatCompletion(BaseLLM):
                     input=data["messages"],
                     api_key=openai_aclient.api_key,
                     additional_args={
-                        "headers": {"Authorization": f"Bearer {openai_aclient.api_key}"},
+                        "headers": {
+                            "Authorization": f"Bearer {openai_aclient.api_key}"
+                        },
                         "api_base": openai_aclient._base_url._uri_reference,
                         "acompletion": True,
                         "complete_input_dict": data,
@@ -1023,6 +1059,33 @@ class OpenAIChatCompletion(BaseLLM):
                     model_response_object=model_response,
                     hidden_params={"headers": headers},
                 )
+            except openai.UnprocessableEntityError as e:
+                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                if litellm.drop_params is True or drop_params is True:
+                    if e.body is not None and e.body.get("detail"):  # type: ignore
+                        detail = e.body.get("detail")  # type: ignore
+                        invalid_params: List[str] = []
+                        if (
+                            isinstance(detail, List)
+                            and len(detail) > 0
+                            and isinstance(detail[0], dict)
+                        ):
+                            for error_dict in detail:
+                                if (
+                                    error_dict.get("loc")
+                                    and isinstance(error_dict.get("loc"), list)
+                                    and len(error_dict.get("loc")) == 2
+                                ):
+                                    invalid_params.append(error_dict["loc"][1])
+
+                    new_data = {}
+                    for k, v in data.items():
+                        if k not in invalid_params:
+                            new_data[k] = v
+                    data = new_data
+                else:
+                    raise e
+                # e.message
             except Exception as e:
                 raise e
@@ -1081,8 +1144,10 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         max_retries=None,
         headers=None,
+        drop_params: Optional[bool] = None,
     ):
         response = None
         for _ in range(2):
             try:
                 openai_aclient = self._get_openai_client(
                     is_async=True,
@@ -1117,6 +1182,32 @@ class OpenAIChatCompletion(BaseLLM):
                     stream_options=data.get("stream_options", None),
                 )
                 return streamwrapper
+            except openai.UnprocessableEntityError as e:
+                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                if litellm.drop_params is True or drop_params is True:
+                    if e.body is not None and e.body.get("detail"):  # type: ignore
+                        detail = e.body.get("detail")  # type: ignore
+                        invalid_params: List[str] = []
+                        if (
+                            isinstance(detail, List)
+                            and len(detail) > 0
+                            and isinstance(detail[0], dict)
+                        ):
+                            for error_dict in detail:
+                                if (
+                                    error_dict.get("loc")
+                                    and isinstance(error_dict.get("loc"), list)
+                                    and len(error_dict.get("loc")) == 2
+                                ):
+                                    invalid_params.append(error_dict["loc"][1])
+
+                    new_data = {}
+                    for k, v in data.items():
+                        if k not in invalid_params:
+                            new_data[k] = v
+                    data = new_data
+                else:
+                    raise e
             except (
                 Exception
             ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
@@ -1127,7 +1218,9 @@ class OpenAIChatCompletion(BaseLLM):
                 )
             else:
                 if type(e).__name__ == "ReadTimeout":
-                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
+                    raise OpenAIError(
+                        status_code=408, message=f"{type(e).__name__}"
+                    )
                 elif hasattr(e, "status_code"):
                     raise OpenAIError(status_code=e.status_code, message=str(e))
                 else:
@@ -1176,6 +1176,7 @@ def completion(
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
                 custom_llm_provider=custom_llm_provider,
+                drop_params=non_default_params.get("drop_params"),
             )
     except Exception as e:
         ## LOGGING - log the original exception returned
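Because the handler checks both the module-level flag and the per-request kwarg (`litellm.drop_params is True or drop_params is True`), either of the following enables the behavior; a brief sketch, with the model and message as placeholders:

    import litellm

    # Option 1: enable globally (matches the `litellm.drop_params is True` branch)
    litellm.drop_params = True

    # Option 2: enable per request (forwarded via non_default_params.get("drop_params"))
    response = litellm.completion(
        model="azure_ai/mistral",
        messages=[{"role": "user", "content": "hi"}],
        presence_penalty=0.1,
        drop_params=True,
    )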
@@ -1,6 +1,9 @@
 model_list:
-  - model_name: azure-chatgpt
+  - model_name: azure-mistral
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+      model: azure_ai/mistral
+      api_key: os.environ/AZURE_AI_MISTRAL_API_KEY
+      api_base: os.environ/AZURE_AI_MISTRAL_API_BASE
+
+litellm_settings:
+  drop_params: true
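For completeness, a hedged sketch of calling the proxy with this config; the base URL, port, and API key are placeholders and assume the proxy is started locally with `litellm --config <this config file>` on its default port:

    # With `drop_params: true` in litellm_settings, unsupported params in this request
    # are dropped instead of returning a 422 from the azure_ai/mistral deployment.
    import openai

    client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")
    response = client.chat.completions.create(
        model="azure-mistral",
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        frequency_penalty=0.1,
    )
    print(response.model)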
@@ -142,6 +142,36 @@ def test_completion_azure_ai_command_r():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "")
+
+        data = {
+            "model": "azure_ai/mistral",
+            "messages": [{"role": "user", "content": "What is the meaning of life?"}],
+            "frequency_penalty": 0.1,
+            "presence_penalty": 0.1,
+            "drop_params": True,
+        }
+        if sync_mode:
+            response: litellm.ModelResponse = completion(**data)  # type: ignore
+        else:
+            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+
+        assert "azure_ai" in response.model
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_azure_command_r():
     try:
         litellm.set_verbose = True
@@ -2881,6 +2881,40 @@ def test_azure_streaming_and_function_calling():
         raise e


+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "")
+
+        data = {
+            "model": "azure_ai/mistral",
+            "messages": [{"role": "user", "content": "What is the meaning of life?"}],
+            "frequency_penalty": 0.1,
+            "presence_penalty": 0.1,
+            "drop_params": True,
+            "stream": True,
+        }
+        if sync_mode:
+            response: litellm.ModelResponse = completion(**data)  # type: ignore
+            for chunk in response:
+                print(chunk)
+        else:
+            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+
+            async for chunk in response:
+                print(chunk)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 async def test_azure_astreaming_and_function_calling():
     import uuid