fix(cohere_chat.py): support passing 'extra_headers'

Fixes https://github.com/BerriAI/litellm/issues/4709
Author: Krrish Dholakia
Date:   2024-08-22 10:16:43 -07:00
parent 36c021b309
commit 11bfc1dca7
5 changed files with 50 additions and 16 deletions
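
After this change, callers can attach per-request headers to Cohere completion, chat, and embedding calls via extra_headers. A minimal usage sketch (assumes litellm is installed and COHERE_API_KEY is set; the Helicone header mirrors the test added in this commit):

import litellm

response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hello"}],
    # forwarded into the outbound Cohere request headers
    extra_headers={"Helicone-Property-Locale": "ko"},
)
print(response)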

litellm/llms/cohere.py (file paths in this view are inferred from the hunk context)

@@ -124,12 +124,14 @@ class CohereConfig:
         }
 
 
-def validate_environment(api_key):
-    headers = {
-        "Request-Source": "unspecified:litellm",
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
+def validate_environment(api_key, headers: dict):
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
     return headers
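
The same validate_environment change is mirrored in cohere_chat.py below. A standalone sketch of the new merge behavior (plain Python, not an import from litellm), to make the precedence explicit:

def validate_environment(api_key, headers: dict):
    # The litellm defaults are update()-ed into the caller's dict, so they
    # overwrite caller-supplied values for these three keys; every other
    # caller header passes through untouched.
    headers.update(
        {
            "Request-Source": "unspecified:litellm",
            "accept": "application/json",
            "content-type": "application/json",
        }
    )
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers

print(validate_environment("sk-test", {"Helicone-Property-Locale": "ko"}))
# {'Helicone-Property-Locale': 'ko', 'Request-Source': 'unspecified:litellm',
#  'accept': 'application/json', 'content-type': 'application/json',
#  'Authorization': 'Bearer sk-test'}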
@@ -144,11 +146,12 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    headers: dict,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
     model = model
     prompt = " ".join(message["content"] for message in messages)
@@ -338,13 +341,14 @@ def embedding(
     model_response: litellm.EmbeddingResponse,
     logging_obj: LiteLLMLoggingObj,
     optional_params: dict,
+    headers: dict,
     encoding: Any,
     api_key: Optional[str] = None,
     aembedding: Optional[bool] = None,
     timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
     client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     embed_url = "https://api.cohere.ai/v1/embed"
     model = model
     data = {"model": model, "texts": input, **optional_params}

litellm/llms/cohere_chat.py

@@ -116,12 +116,14 @@ class CohereChatConfig:
         }
 
 
-def validate_environment(api_key):
-    headers = {
-        "Request-Source": "unspecified:litellm",
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
+def validate_environment(api_key, headers: dict):
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
     return headers
@@ -203,13 +205,14 @@ def completion(
     model_response: ModelResponse,
     print_verbose: Callable,
     optional_params: dict,
+    headers: dict,
     encoding,
     api_key,
     logging_obj,
     litellm_params=None,
     logger_fn=None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
     model = model
     most_recent_message, chat_history = cohere_messages_pt_v2(

litellm/main.py

@@ -1634,6 +1634,13 @@ def completion(
                 or "https://api.cohere.ai/v1/generate"
             )
 
+            headers = headers or litellm.headers or {}
+            if headers is None:
+                headers = {}
+
+            if extra_headers is not None:
+                headers.update(extra_headers)
+
             model_response = cohere.completion(
                 model=model,
                 messages=messages,
@@ -1644,6 +1651,7 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 encoding=encoding,
+                headers=headers,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )
@@ -1674,6 +1682,13 @@ def completion(
                 or "https://api.cohere.ai/v1/chat"
             )
 
+            headers = headers or litellm.headers or {}
+            if headers is None:
+                headers = {}
+
+            if extra_headers is not None:
+                headers.update(extra_headers)
+
             model_response = cohere_chat.completion(
                 model=model,
                 messages=messages,
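
Both provider branches resolve headers identically: the per-call headers argument takes precedence over the module-level litellm.headers default, and extra_headers is merged last, so it wins on key conflicts (the "if headers is None" guard is redundant after the "or" chain, but harmless). A hypothetical standalone helper capturing the same order; litellm does this inline:

def resolve_headers(headers=None, extra_headers=None, default_headers=None):
    # per-call headers > module-level default > empty dict
    headers = headers or default_headers or {}
    if extra_headers is not None:
        headers.update(extra_headers)  # extra_headers overrides on conflicts
    return headers

resolve_headers(extra_headers={"Helicone-Property-Locale": "ko"})
# -> {'Helicone-Property-Locale': 'ko'}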
@@ -1682,6 +1697,7 @@ def completion(
                 print_verbose=print_verbose,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
+                headers=headers,
                 logger_fn=logger_fn,
                 encoding=encoding,
                 api_key=cohere_key,
@@ -3158,6 +3174,7 @@ def embedding(
     encoding_format = kwargs.get("encoding_format", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
+    extra_headers = kwargs.get("extra_headers", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -3229,6 +3246,7 @@ def embedding(
         "model_config",
         "cooldown_time",
         "tags",
+        "extra_headers",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
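
Listing "extra_headers" under litellm_params matters because default_params (openai_params + litellm_params) is the allow-list used to split caller kwargs: anything not in it is treated as a provider-specific parameter and forwarded to the model. Roughly (a sketch with hypothetical names, not litellm's exact code):

non_default_params = {
    k: v for k, v in passed_params.items() if k not in default_params
}
# with "extra_headers" in default_params, litellm handles it itself
# instead of passing it through as a Cohere API parameter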
@@ -3292,7 +3310,7 @@ def embedding(
             "cooldown_time": cooldown_time,
         },
     )
 
-    if azure == True or custom_llm_provider == "azure":
+    if azure is True or custom_llm_provider == "azure":
         # azure configs
         api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -3398,12 +3416,18 @@ def embedding(
             or get_secret("CO_API_KEY")
             or litellm.api_key
         )
+
+        if extra_headers is not None and isinstance(extra_headers, dict):
+            headers = extra_headers
+        else:
+            headers = {}
         response = cohere.embedding(
             model=model,
             input=input,
             optional_params=optional_params,
             encoding=encoding,
             api_key=cohere_key,  # type: ignore
+            headers=headers,
             logging_obj=logging,
             model_response=EmbeddingResponse(),
             aembedding=aembedding,
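
Note that the embedding path consults only extra_headers; unlike the completion branches above, litellm.headers is not merged here. A usage sketch (assumes COHERE_API_KEY is set; the model name is just an illustrative Cohere embedding model):

import litellm

response = litellm.embedding(
    model="cohere/embed-english-v3.0",
    input=["hello world"],
    # becomes the headers dict passed to cohere.embedding()
    extra_headers={"Helicone-Property-Locale": "ko"},
)
print(response)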

litellm/tests/test_completion.py

@@ -3653,6 +3653,7 @@ def test_completion_cohere():
         response = completion(
             model="command-r",
             messages=messages,
+            extra_headers={"Helicone-Property-Locale": "ko"},
         )
         print(response)
     except Exception as e:

litellm/utils.py

@@ -4219,6 +4219,7 @@ def get_supported_openai_params(
             "presence_penalty",
             "stop",
             "n",
+            "extra_headers",
         ]
     elif custom_llm_provider == "cohere_chat":
         return [
@@ -4233,6 +4234,7 @@ def get_supported_openai_params(
             "tools",
             "tool_choice",
             "seed",
+            "extra_headers",
         ]
     elif custom_llm_provider == "maritalk":
         return [
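
A quick way to confirm that both providers now advertise the parameter (a sketch; get_supported_openai_params is importable from the litellm package):

from litellm import get_supported_openai_params

assert "extra_headers" in get_supported_openai_params(
    model="command-nightly", custom_llm_provider="cohere"
)
assert "extra_headers" in get_supported_openai_params(
    model="command-r", custom_llm_provider="cohere_chat"
)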