diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
index 3873027b2a..8bd1051e84 100644
--- a/litellm/llms/cohere.py
+++ b/litellm/llms/cohere.py
@@ -124,12 +124,14 @@ class CohereConfig:
         }
 
 
-def validate_environment(api_key):
-    headers = {
-        "Request-Source": "unspecified:litellm",
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
+def validate_environment(api_key, headers: dict):
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
     return headers
@@ -144,11 +146,12 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    headers: dict,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
     model = model
     prompt = " ".join(message["content"] for message in messages)
@@ -338,13 +341,14 @@ def embedding(
     model_response: litellm.EmbeddingResponse,
     logging_obj: LiteLLMLoggingObj,
     optional_params: dict,
+    headers: dict,
     encoding: Any,
     api_key: Optional[str] = None,
     aembedding: Optional[bool] = None,
     timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
     client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     embed_url = "https://api.cohere.ai/v1/embed"
     model = model
     data = {"model": model, "texts": input, **optional_params}
diff --git a/litellm/llms/cohere_chat.py b/litellm/llms/cohere_chat.py
index a0a9a98749..f13e74614b 100644
--- a/litellm/llms/cohere_chat.py
+++ b/litellm/llms/cohere_chat.py
@@ -116,12 +116,14 @@ class CohereChatConfig:
         }
 
 
-def validate_environment(api_key):
-    headers = {
-        "Request-Source": "unspecified:litellm",
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
+def validate_environment(api_key, headers: dict):
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
     return headers
@@ -203,13 +205,14 @@ def completion(
     model_response: ModelResponse,
     print_verbose: Callable,
     optional_params: dict,
+    headers: dict,
     encoding,
     api_key,
     logging_obj,
     litellm_params=None,
     logger_fn=None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
     model = model
     most_recent_message, chat_history = cohere_messages_pt_v2(
diff --git a/litellm/main.py b/litellm/main.py
index 80a9a94a34..1beca01137 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1634,6 +1634,13 @@ def completion(
             or "https://api.cohere.ai/v1/generate"
         )
 
+        headers = headers or litellm.headers or {}
+        if headers is None:
+            headers = {}
+
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
         model_response = cohere.completion(
             model=model,
             messages=messages,
@@ -1644,6 +1651,7 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             encoding=encoding,
+            headers=headers,
             api_key=cohere_key,
             logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )
@@ -1674,6 +1682,13 @@ def completion(
             or "https://api.cohere.ai/v1/chat"
         )
 
+        headers = headers or litellm.headers or {}
+        if headers is None:
+            headers = {}
+
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
         model_response = cohere_chat.completion(
             model=model,
             messages=messages,
@@ -1682,6 +1697,7 @@ def completion(
             print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
+            headers=headers,
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=cohere_key,
@@ -3158,6 +3174,7 @@ def embedding(
     encoding_format = kwargs.get("encoding_format", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
+    extra_headers = kwargs.get("extra_headers", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -3229,6 +3246,7 @@ def embedding(
         "model_config",
         "cooldown_time",
         "tags",
+        "extra_headers",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3292,7 +3310,7 @@ def embedding(
             "cooldown_time": cooldown_time,
         },
     )
 
-    if azure == True or custom_llm_provider == "azure":
+    if azure is True or custom_llm_provider == "azure":
         # azure configs
         api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -3398,12 +3416,18 @@ def embedding(
             or get_secret("CO_API_KEY")
             or litellm.api_key
         )
+
+        if extra_headers is not None and isinstance(extra_headers, dict):
+            headers = extra_headers
+        else:
+            headers = {}
         response = cohere.embedding(
             model=model,
             input=input,
             optional_params=optional_params,
             encoding=encoding,
             api_key=cohere_key,  # type: ignore
+            headers=headers,
             logging_obj=logging,
             model_response=EmbeddingResponse(),
             aembedding=aembedding,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 0941484d95..c0c3c70f92 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -3653,6 +3653,7 @@ def test_completion_cohere():
         response = completion(
             model="command-r",
             messages=messages,
+            extra_headers={"Helicone-Property-Locale": "ko"},
         )
         print(response)
     except Exception as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 0e9e531e98..f3bb944a84 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4219,6 +4219,7 @@ def get_supported_openai_params(
             "presence_penalty",
             "stop",
             "n",
+            "extra_headers",
         ]
     elif custom_llm_provider == "cohere_chat":
         return [
@@ -4233,6 +4234,7 @@
             "tools",
             "tool_choice",
             "seed",
+            "extra_headers",
         ]
     elif custom_llm_provider == "maritalk":
         return [
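Usage sketch (not part of the patch): after this change, a per-request `extra_headers` dict is merged into the outbound Cohere request for completion, chat, and embedding calls. The Helicone header below mirrors the one used in `test_completion_cohere` above; the embedding model name is an assumption for illustration only.

import litellm

# Completion: extra_headers flows through the headers plumbing added in
# litellm/main.py and is merged into the Cohere request headers.
response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    extra_headers={"Helicone-Property-Locale": "ko"},  # forwarded to Cohere
)

# Embedding: the same kwarg now reaches cohere.embedding() as `headers`.
# "cohere/embed-english-v3.0" is an assumed model name for illustration.
embedding = litellm.embedding(
    model="cohere/embed-english-v3.0",
    input=["hello from litellm"],
    extra_headers={"Helicone-Property-Locale": "ko"},
)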