diff --git a/litellm/main.py b/litellm/main.py index 1ee36504f1..120f85cb1e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -398,6 +398,7 @@ def completion( logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, deployment_id=None, + extra_headers: Optional[dict] = None, # soon to be deprecated params by OpenAI functions: Optional[List] = None, function_call: Optional[str] = None, @@ -514,6 +515,7 @@ def completion( "max_retries", "logprobs", "top_logprobs", + "extra_headers", ] litellm_params = [ "metadata", @@ -691,6 +693,7 @@ def completion( max_retries=max_retries, logprobs=logprobs, top_logprobs=top_logprobs, + extra_headers=extra_headers, **non_default_params, ) diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index 45c3b8a38e..4fa8df3b6d 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -118,3 +118,18 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body(): # test_azure_gpt_optional_params_gpt_vision_with_extra_body() + + +def test_openai_extra_headers(): + optional_params = litellm.utils.get_optional_params( + user="John", + custom_llm_provider="openai", + max_tokens=10, + temperature=0.2, + extra_headers={"AI-Resource Group": "ishaan-resource"}, + ) + + print(optional_params) + assert optional_params["max_tokens"] == 10 + assert optional_params["temperature"] == 0.2 + assert optional_params["extra_headers"] == {"AI-Resource Group": "ishaan-resource"} diff --git a/litellm/utils.py b/litellm/utils.py index e3516f7fdc..ed130c247a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3907,6 +3907,7 @@ def get_optional_params( max_retries=None, logprobs=None, top_logprobs=None, + extra_headers=None, **kwargs, ): # retrieve all parameters passed to the function @@ -3946,6 +3947,7 @@ def get_optional_params( "max_retries": None, "logprobs": None, "top_logprobs": None, + "extra_headers": None, } # filter out those parameters that were passed with non-default values non_default_params = { @@ -4753,6 +4755,7 @@ def get_optional_params( "max_retries", "logprobs", "top_logprobs", + "extra_headers", ] _check_valid_arg(supported_params=supported_params) if functions is not None: @@ -4793,6 +4796,8 @@ def get_optional_params( optional_params["logprobs"] = logprobs if top_logprobs is not None: optional_params["top_logprobs"] = top_logprobs + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers if custom_llm_provider in ["openai", "azure"] + litellm.openai_compatible_providers: # for openai, azure we should pass the extra/passed params within `extra_body` https://github.com/openai/openai-python/blob/ac33853ba10d13ac149b1fa3ca6dba7d613065c9/src/openai/resources/models.py#L46 extra_body = passed_params.pop("extra_body", {})