Merge pull request #4806 from BerriAI/litellm_drop_invalid_params
fix(openai.py): drop invalid params if `drop_params: true` for azure ai
commit fee4f3385d

5 changed files with 251 additions and 90 deletions
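In practice a caller opts into this behaviour either globally or per request. A minimal sketch of the per-request form, mirroring the test added in this PR (the exact params the endpoint rejects are an assumption taken from that test, not a guarantee):

    import litellm

    # Azure AI endpoints can reject some OpenAI-style params with a 422;
    # with drop_params=True litellm strips the rejected params and retries.
    response = litellm.completion(
        model="azure_ai/mistral",
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        frequency_penalty=0.1,   # may be rejected by the upstream endpoint
        presence_penalty=0.1,
        drop_params=True,
    )
    print(response)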
@@ -803,6 +803,7 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         organization: Optional[str] = None,
         custom_llm_provider: Optional[str] = None,
+        drop_params: Optional[bool] = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -858,6 +859,7 @@ class OpenAIChatCompletion(BaseLLM):
                         client=client,
                         max_retries=max_retries,
                         organization=organization,
+                        drop_params=drop_params,
                     )
                 else:
                     return self.acompletion(
@@ -871,6 +873,7 @@ class OpenAIChatCompletion(BaseLLM):
                         client=client,
                         max_retries=max_retries,
                         organization=organization,
+                        drop_params=drop_params,
                     )
             elif optional_params.get("stream", False):
                 return self.streaming(
@@ -925,6 +928,33 @@ class OpenAIChatCompletion(BaseLLM):
                         response_object=stringified_response,
                         model_response_object=model_response,
                     )
+                except openai.UnprocessableEntityError as e:
+                    ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                    if litellm.drop_params is True or drop_params is True:
+                        if e.body is not None and e.body.get("detail"):  # type: ignore
+                            detail = e.body.get("detail")  # type: ignore
+                            invalid_params: List[str] = []
+                            if (
+                                isinstance(detail, List)
+                                and len(detail) > 0
+                                and isinstance(detail[0], dict)
+                            ):
+                                for error_dict in detail:
+                                    if (
+                                        error_dict.get("loc")
+                                        and isinstance(error_dict.get("loc"), list)
+                                        and len(error_dict.get("loc")) == 2
+                                    ):
+                                        invalid_params.append(error_dict["loc"][1])
+
+                        new_data = {}
+                        for k, v in optional_params.items():
+                            if k not in invalid_params:
+                                new_data[k] = v
+                        optional_params = new_data
+                    else:
+                        raise e
+                    # e.message
                 except Exception as e:
                     if print_verbose is not None:
                         print_verbose(f"openai.py: Received openai error - {str(e)}")
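For reference, a standalone sketch of what this new except branch does with a 422 body. The sample payload shape is an assumption inferred from the `loc` handling above (a FastAPI-style validation error), not copied from a real Azure AI response:

    # Assumed 422 body shape: {"detail": [{"loc": [..., <param_name>], ...}, ...]}
    error_body = {
        "detail": [
            {"loc": ["body", "frequency_penalty"], "msg": "extra fields not permitted"},
            {"loc": ["body", "presence_penalty"], "msg": "extra fields not permitted"},
        ]
    }

    # Collect the rejected parameter names (second element of each "loc").
    invalid_params = []
    for error_dict in error_body["detail"]:
        loc = error_dict.get("loc")
        if loc and isinstance(loc, list) and len(loc) == 2:
            invalid_params.append(loc[1])

    # Rebuild the request params without the rejected keys, then retry the call.
    optional_params = {"temperature": 0.7, "frequency_penalty": 0.1, "presence_penalty": 0.1}
    retried_params = {k: v for k, v in optional_params.items() if k not in invalid_params}
    print(retried_params)  # {'temperature': 0.7}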

@@ -982,49 +1012,82 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         max_retries=None,
         headers=None,
+        drop_params: Optional[bool] = None,
     ):
         response = None
-        try:
-            openai_aclient = self._get_openai_client(
-                is_async=True,
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-                client=client,
-            )
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["messages"],
-                api_key=openai_aclient.api_key,
-                additional_args={
-                    "headers": {"Authorization": f"Bearer {openai_aclient.api_key}"},
-                    "api_base": openai_aclient._base_url._uri_reference,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
-            )
-
-            headers, response = await self.make_openai_chat_completion_request(
-                openai_aclient=openai_aclient, data=data, timeout=timeout
-            )
-            stringified_response = response.model_dump()
-            logging_obj.post_call(
-                input=data["messages"],
-                api_key=api_key,
-                original_response=stringified_response,
-                additional_args={"complete_input_dict": data},
-            )
-            logging_obj.model_call_details["response_headers"] = headers
-            return convert_to_model_response_object(
-                response_object=stringified_response,
-                model_response_object=model_response,
-                hidden_params={"headers": headers},
-            )
-        except Exception as e:
-            raise e
+        for _ in range(
+            2
+        ):  # if call fails due to alternating messages, retry with reformatted message
+            try:
+                openai_aclient = self._get_openai_client(
+                    is_async=True,
+                    api_key=api_key,
+                    api_base=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    client=client,
+                )
+
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=data["messages"],
+                    api_key=openai_aclient.api_key,
+                    additional_args={
+                        "headers": {
+                            "Authorization": f"Bearer {openai_aclient.api_key}"
+                        },
+                        "api_base": openai_aclient._base_url._uri_reference,
+                        "acompletion": True,
+                        "complete_input_dict": data,
+                    },
+                )
+
+                headers, response = await self.make_openai_chat_completion_request(
+                    openai_aclient=openai_aclient, data=data, timeout=timeout
+                )
+                stringified_response = response.model_dump()
+                logging_obj.post_call(
+                    input=data["messages"],
+                    api_key=api_key,
+                    original_response=stringified_response,
+                    additional_args={"complete_input_dict": data},
+                )
+                logging_obj.model_call_details["response_headers"] = headers
+                return convert_to_model_response_object(
+                    response_object=stringified_response,
+                    model_response_object=model_response,
+                    hidden_params={"headers": headers},
+                )
+            except openai.UnprocessableEntityError as e:
+                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                if litellm.drop_params is True or drop_params is True:
+                    if e.body is not None and e.body.get("detail"):  # type: ignore
+                        detail = e.body.get("detail")  # type: ignore
+                        invalid_params: List[str] = []
+                        if (
+                            isinstance(detail, List)
+                            and len(detail) > 0
+                            and isinstance(detail[0], dict)
+                        ):
+                            for error_dict in detail:
+                                if (
+                                    error_dict.get("loc")
+                                    and isinstance(error_dict.get("loc"), list)
+                                    and len(error_dict.get("loc")) == 2
+                                ):
+                                    invalid_params.append(error_dict["loc"][1])
+
+                    new_data = {}
+                    for k, v in data.items():
+                        if k not in invalid_params:
+                            new_data[k] = v
+                    data = new_data
+                else:
+                    raise e
+                # e.message
+            except Exception as e:
+                raise e

     def streaming(
         self,

@@ -1081,57 +1144,87 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         max_retries=None,
         headers=None,
+        drop_params: Optional[bool] = None,
     ):
         response = None
-        try:
-            openai_aclient = self._get_openai_client(
-                is_async=True,
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-                client=client,
-            )
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["messages"],
-                api_key=api_key,
-                additional_args={
-                    "headers": headers,
-                    "api_base": api_base,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
-            )
-
-            headers, response = await self.make_openai_chat_completion_request(
-                openai_aclient=openai_aclient, data=data, timeout=timeout
-            )
-            logging_obj.model_call_details["response_headers"] = headers
-            streamwrapper = CustomStreamWrapper(
-                completion_stream=response,
-                model=model,
-                custom_llm_provider="openai",
-                logging_obj=logging_obj,
-                stream_options=data.get("stream_options", None),
-            )
-            return streamwrapper
-        except (
-            Exception
-        ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
-            if response is not None and hasattr(response, "text"):
-                raise OpenAIError(
-                    status_code=500,
-                    message=f"{str(e)}\n\nOriginal Response: {response.text}",
-                )
-            else:
-                if type(e).__name__ == "ReadTimeout":
-                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
-                elif hasattr(e, "status_code"):
-                    raise OpenAIError(status_code=e.status_code, message=str(e))
-                else:
-                    raise OpenAIError(status_code=500, message=f"{str(e)}")
+        for _ in range(2):
+            try:
+                openai_aclient = self._get_openai_client(
+                    is_async=True,
+                    api_key=api_key,
+                    api_base=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    client=client,
+                )
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=data["messages"],
+                    api_key=api_key,
+                    additional_args={
+                        "headers": headers,
+                        "api_base": api_base,
+                        "acompletion": True,
+                        "complete_input_dict": data,
+                    },
+                )
+
+                headers, response = await self.make_openai_chat_completion_request(
+                    openai_aclient=openai_aclient, data=data, timeout=timeout
+                )
+                logging_obj.model_call_details["response_headers"] = headers
+                streamwrapper = CustomStreamWrapper(
+                    completion_stream=response,
+                    model=model,
+                    custom_llm_provider="openai",
+                    logging_obj=logging_obj,
+                    stream_options=data.get("stream_options", None),
+                )
+                return streamwrapper
+            except openai.UnprocessableEntityError as e:
+                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
+                if litellm.drop_params is True or drop_params is True:
+                    if e.body is not None and e.body.get("detail"):  # type: ignore
+                        detail = e.body.get("detail")  # type: ignore
+                        invalid_params: List[str] = []
+                        if (
+                            isinstance(detail, List)
+                            and len(detail) > 0
+                            and isinstance(detail[0], dict)
+                        ):
+                            for error_dict in detail:
+                                if (
+                                    error_dict.get("loc")
+                                    and isinstance(error_dict.get("loc"), list)
+                                    and len(error_dict.get("loc")) == 2
+                                ):
+                                    invalid_params.append(error_dict["loc"][1])
+
+                    new_data = {}
+                    for k, v in data.items():
+                        if k not in invalid_params:
+                            new_data[k] = v
+                    data = new_data
+                else:
+                    raise e
+            except (
+                Exception
+            ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
+                if response is not None and hasattr(response, "text"):
+                    raise OpenAIError(
+                        status_code=500,
+                        message=f"{str(e)}\n\nOriginal Response: {response.text}",
+                    )
+                else:
+                    if type(e).__name__ == "ReadTimeout":
+                        raise OpenAIError(
+                            status_code=408, message=f"{type(e).__name__}"
+                        )
+                    elif hasattr(e, "status_code"):
+                        raise OpenAIError(status_code=e.status_code, message=str(e))
+                    else:
+                        raise OpenAIError(status_code=500, message=f"{str(e)}")

     # Embedding
     async def make_openai_embedding_request(

@@ -1176,6 +1176,7 @@ def completion(
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
                 custom_llm_provider=custom_llm_provider,
+                drop_params=non_default_params.get("drop_params"),
             )
         except Exception as e:
             ## LOGGING - log the original exception returned

@@ -1,6 +1,9 @@
 model_list:
-  - model_name: azure-chatgpt
+  - model_name: azure-mistral
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+      model: azure_ai/mistral
+      api_key: os.environ/AZURE_AI_MISTRAL_API_KEY
+      api_base: os.environ/AZURE_AI_MISTRAL_API_BASE
+
+litellm_settings:
+  drop_params: true
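The `litellm_settings: drop_params: true` entry above maps onto the same module-level flag the new except branches check (`litellm.drop_params`). When using the SDK directly rather than the proxy, the global opt-in would look roughly like this sketch:

    import litellm

    # Global opt-in: equivalent to `litellm_settings: drop_params: true` in the proxy config.
    litellm.drop_params = True

    response = litellm.completion(
        model="azure_ai/mistral",
        messages=[{"role": "user", "content": "hello"}],
        presence_penalty=0.1,  # dropped and retried if the endpoint 422s on it
    )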

@@ -142,6 +142,36 @@ def test_completion_azure_ai_command_r():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "")
+
+        data = {
+            "model": "azure_ai/mistral",
+            "messages": [{"role": "user", "content": "What is the meaning of life?"}],
+            "frequency_penalty": 0.1,
+            "presence_penalty": 0.1,
+            "drop_params": True,
+        }
+        if sync_mode:
+            response: litellm.ModelResponse = completion(**data)  # type: ignore
+        else:
+            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+
+        assert "azure_ai" in response.model
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_azure_command_r():
     try:
         litellm.set_verbose = True

@@ -2881,6 +2881,40 @@ def test_azure_streaming_and_function_calling():
         raise e


+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_MISTRAL_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_MISTRAL_API_KEY", "")
+
+        data = {
+            "model": "azure_ai/mistral",
+            "messages": [{"role": "user", "content": "What is the meaning of life?"}],
+            "frequency_penalty": 0.1,
+            "presence_penalty": 0.1,
+            "drop_params": True,
+            "stream": True,
+        }
+        if sync_mode:
+            response: litellm.ModelResponse = completion(**data)  # type: ignore
+            for chunk in response:
+                print(chunk)
+        else:
+            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+
+            async for chunk in response:
+                print(chunk)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 async def test_azure_astreaming_and_function_calling():
     import uuid