fix(cohere_chat.py): support passing 'extra_headers'

Fixes https://github.com/BerriAI/litellm/issues/4709
This commit is contained in:
Krrish Dholakia 2024-08-22 10:16:43 -07:00
parent 7d10451bc8
commit 8f306f8e41
5 changed files with 50 additions and 16 deletions

View file

@ -124,12 +124,14 @@ class CohereConfig:
}
def validate_environment(api_key):
headers = {
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
def validate_environment(api_key, headers: dict):
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@ -144,11 +146,12 @@ def completion(
encoding,
api_key,
logging_obj,
headers: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
prompt = " ".join(message["content"] for message in messages)
@ -338,13 +341,14 @@ def embedding(
model_response: litellm.EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
headers: dict,
encoding: Any,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
headers = validate_environment(api_key)
headers = validate_environment(api_key, headers=headers)
embed_url = "https://api.cohere.ai/v1/embed"
model = model
data = {"model": model, "texts": input, **optional_params}

View file

@ -116,12 +116,14 @@ class CohereChatConfig:
}
def validate_environment(api_key):
headers = {
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
def validate_environment(api_key, headers: dict):
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@ -203,13 +205,14 @@ def completion(
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
headers: dict,
encoding,
api_key,
logging_obj,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
most_recent_message, chat_history = cohere_messages_pt_v2(

View file

@ -1634,6 +1634,13 @@ def completion(
or "https://api.cohere.ai/v1/generate"
)
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere.completion(
model=model,
messages=messages,
@ -1644,6 +1651,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
headers=headers,
api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
@ -1674,6 +1682,13 @@ def completion(
or "https://api.cohere.ai/v1/chat"
)
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere_chat.completion(
model=model,
messages=messages,
@ -1682,6 +1697,7 @@ def completion(
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
@ -3158,6 +3174,7 @@ def embedding(
encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.get("aembedding", None)
extra_headers = kwargs.get("extra_headers", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
@ -3229,6 +3246,7 @@ def embedding(
"model_config",
"cooldown_time",
"tags",
"extra_headers",
]
default_params = openai_params + litellm_params
non_default_params = {
@ -3292,7 +3310,7 @@ def embedding(
"cooldown_time": cooldown_time,
},
)
if azure == True or custom_llm_provider == "azure":
if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@ -3398,12 +3416,18 @@ def embedding(
or get_secret("CO_API_KEY")
or litellm.api_key
)
if extra_headers is not None and isinstance(extra_headers, dict):
headers = extra_headers
else:
headers = {}
response = cohere.embedding(
model=model,
input=input,
optional_params=optional_params,
encoding=encoding,
api_key=cohere_key, # type: ignore
headers=headers,
logging_obj=logging,
model_response=EmbeddingResponse(),
aembedding=aembedding,

View file

@ -3653,6 +3653,7 @@ def test_completion_cohere():
response = completion(
model="command-r",
messages=messages,
extra_headers={"Helicone-Property-Locale": "ko"},
)
print(response)
except Exception as e:

View file

@ -4219,6 +4219,7 @@ def get_supported_openai_params(
"presence_penalty",
"stop",
"n",
"extra_headers",
]
elif custom_llm_provider == "cohere_chat":
return [
@ -4233,6 +4234,7 @@ def get_supported_openai_params(
"tools",
"tool_choice",
"seed",
"extra_headers",
]
elif custom_llm_provider == "maritalk":
return [