From 3c8b58bd80573ae514594ef40e2c676f1d827997 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Fri, 23 Feb 2024 08:48:21 -0800 Subject: [PATCH] (feat) support extra_headers --- litellm/main.py | 3 +++ litellm/tests/test_optional_params.py | 15 +++++++++++++++ litellm/utils.py | 5 +++++ 3 files changed, 23 insertions(+) diff --git a/litellm/main.py b/litellm/main.py index 1ee36504f..120f85cb1 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -398,6 +398,7 @@ def completion( logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, deployment_id=None, + extra_headers: Optional[dict] = None, # soon to be deprecated params by OpenAI functions: Optional[List] = None, function_call: Optional[str] = None, @@ -514,6 +515,7 @@ def completion( "max_retries", "logprobs", "top_logprobs", + "extra_headers", ] litellm_params = [ "metadata", @@ -691,6 +693,7 @@ def completion( max_retries=max_retries, logprobs=logprobs, top_logprobs=top_logprobs, + extra_headers=extra_headers, **non_default_params, ) diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index 45c3b8a38..4fa8df3b6 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -118,3 +118,18 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body(): # test_azure_gpt_optional_params_gpt_vision_with_extra_body() + + +def test_openai_extra_headers(): + optional_params = litellm.utils.get_optional_params( + user="John", + custom_llm_provider="openai", + max_tokens=10, + temperature=0.2, + extra_headers={"AI-Resource Group": "ishaan-resource"}, + ) + + print(optional_params) + assert optional_params["max_tokens"] == 10 + assert optional_params["temperature"] == 0.2 + assert optional_params["extra_headers"] == {"AI-Resource Group": "ishaan-resource"} diff --git a/litellm/utils.py b/litellm/utils.py index e3516f7fd..ed130c247 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3907,6 +3907,7 @@ def get_optional_params( max_retries=None, logprobs=None, top_logprobs=None, + extra_headers=None, **kwargs, ): # retrieve all parameters passed to the function @@ -3946,6 +3947,7 @@ def get_optional_params( "max_retries": None, "logprobs": None, "top_logprobs": None, + "extra_headers": None, } # filter out those parameters that were passed with non-default values non_default_params = { @@ -4753,6 +4755,7 @@ def get_optional_params( "max_retries", "logprobs", "top_logprobs", + "extra_headers", ] _check_valid_arg(supported_params=supported_params) if functions is not None: @@ -4793,6 +4796,8 @@ def get_optional_params( optional_params["logprobs"] = logprobs if top_logprobs is not None: optional_params["top_logprobs"] = top_logprobs + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers if custom_llm_provider in ["openai", "azure"] + litellm.openai_compatible_providers: # for openai, azure we should pass the extra/passed params within `extra_body` https://github.com/openai/openai-python/blob/ac33853ba10d13ac149b1fa3ca6dba7d613065c9/src/openai/resources/models.py#L46 extra_body = passed_params.pop("extra_body", {})