fix _apply_openai_param_overrides

Author: Ishaan Jaff
Date:   2025-04-01 21:17:59 -07:00
Commit: f7129e5e59 (parent: 9acda77b75)

3 changed files with 51 additions and 3 deletions


@@ -71,6 +71,7 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
             verbose_logger.debug(
                 f"Unable to infer model provider for model={model}, defaulting to openai for o1 supported param check"
             )
         return [
             param for param in all_openai_params if param not in non_supported_params
         ]
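For orientation, the return above implements a subtract-the-blocklist pattern: o-series support is computed by starting from every OpenAI param and dropping what the model family rejects. A minimal standalone sketch; the two lists here are illustrative stand-ins, not LiteLLM's actual values:

from typing import List

# Illustrative inputs; the real lists live on LiteLLM's config classes.
all_openai_params: List[str] = [
    "temperature", "top_p", "max_tokens", "tools", "response_format", "reasoning_effort",
]
non_supported_params: List[str] = ["temperature", "top_p"]  # assumption: o-series drops sampling knobs

# Same shape as the `return [...]` in the hunk above.
supported = [param for param in all_openai_params if param not in non_supported_params]
print(supported)  # ['max_tokens', 'tools', 'response_format', 'reasoning_effort']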


@@ -3051,7 +3051,7 @@ def get_optional_params(  # noqa: PLR0915
                 new_parameters.pop("additionalProperties", None)
             tool_function["parameters"] = new_parameters

-    def _check_valid_arg(supported_params: List[str], allowed_openai_params: List[str]):
+    def _check_valid_arg(supported_params: List[str]):
         """
         Check if the params passed to completion() are supported by the provider
@@ -3068,7 +3068,6 @@ def get_optional_params(  # noqa: PLR0915
         verbose_logger.debug(
             f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}"
         )
-        supported_params = supported_params + allowed_openai_params
         unsupported_params = {}
         for k in non_default_params.keys():
             if k not in supported_params:
@@ -3103,9 +3102,13 @@ def get_optional_params(  # noqa: PLR0915
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider="openai"
         )
+        supported_params = supported_params or []
+        allowed_openai_params = allowed_openai_params or []
+        supported_params.extend(allowed_openai_params)
         _check_valid_arg(
             supported_params=supported_params or [],
-            allowed_openai_params=allowed_openai_params or [],
         )

     ## raise exception if provider doesn't support passed in param
     if custom_llm_provider == "anthropic":
@@ -3745,6 +3748,26 @@ def get_optional_params(  # noqa: PLR0915
         if k not in default_params.keys():
             optional_params[k] = passed_params[k]
     print_verbose(f"Final returned optional params: {optional_params}")
+    optional_params = _apply_openai_param_overrides(
+        optional_params=optional_params,
+        non_default_params=non_default_params,
+        allowed_openai_params=allowed_openai_params,
+    )
     return optional_params


+def _apply_openai_param_overrides(
+    optional_params: dict, non_default_params: dict, allowed_openai_params: list
+):
+    """
+    If the user passes in allowed_openai_params, apply them to optional_params.
+
+    These params will get passed as-is to the LLM API, since the user opted in
+    to passing them in the request.
+    """
+    if allowed_openai_params:
+        for param in allowed_openai_params:
+            if param not in optional_params:
+                optional_params[param] = non_default_params.pop(param, None)
+    return optional_params
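To see the new helper's effect in isolation, here is a runnable snippet that copies the committed function body and feeds it made-up inputs: an allow-listed param is forwarded untouched, while anything the provider mapping already set is left alone.

def _apply_openai_param_overrides(
    optional_params: dict, non_default_params: dict, allowed_openai_params: list
) -> dict:
    # Copy each allow-listed param through as-is, unless the provider
    # mapping already produced a value for it.
    if allowed_openai_params:
        for param in allowed_openai_params:
            if param not in optional_params:
                optional_params[param] = non_default_params.pop(param, None)
    return optional_params

# Sample dicts are made up for the demo.
optional_params = {"max_tokens": 100}  # what the provider mapping kept
non_default_params = {"max_tokens": 100, "reasoning_effort": "low"}
print(_apply_openai_param_overrides(optional_params, non_default_params, ["reasoning_effort"]))
# -> {'max_tokens': 100, 'reasoning_effort': 'low'}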


@@ -67,6 +67,30 @@ def test_anthropic_optional_params(stop_sequence, expected_count):
     assert len(optional_params) == expected_count


+def test_get_optional_params_with_allowed_openai_params():
+    """
+    Test if the user can dynamically pass in allowed_openai_params to override default behavior
+    """
+    litellm.drop_params = True
+    tools = [{'type': 'function', 'function': {'name': 'get_current_time', 'description': 'Get the current time in a given location.', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city name, e.g. San Francisco'}}, 'required': ['location']}}}]
+    response_format = {"type": "json"}
+    reasoning_effort = "low"
+    optional_params = get_optional_params(
+        model="cf/llama-3.1-70b-instruct",
+        custom_llm_provider="cloudflare",
+        allowed_openai_params=["tools", "reasoning_effort", "response_format"],
+        tools=tools,
+        response_format=response_format,
+        reasoning_effort=reasoning_effort,
+    )
+    print(f"optional_params: {optional_params}")
+    assert optional_params["tools"] == tools
+    assert optional_params["response_format"] == response_format
+    assert optional_params["reasoning_effort"] == reasoning_effort
+
+
 def test_bedrock_optional_params_embeddings():
     litellm.drop_params = True
     optional_params = get_optional_params_embeddings(
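What this enables for callers: the test exercises get_optional_params directly, but the intended entry point is the completion call. The sketch below assumes completion() forwards an allowed_openai_params kwarg the same way the test does; that threading is not shown in this diff, and the model id is illustrative.

import litellm

# Hedged end-to-end sketch: opt a param into a provider whose mapping
# would otherwise drop it, so it is sent to the API as-is.
response = litellm.completion(
    model="cloudflare/@cf/meta/llama-3.1-8b-instruct",  # illustrative model id
    messages=[{"role": "user", "content": "What time is it in Paris?"}],
    allowed_openai_params=["response_format"],  # opt in: forward even if unmapped
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)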