mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(cohere_chat.py): support passing 'extra_headers'
Fixes https://github.com/BerriAI/litellm/issues/4709
This commit is contained in:
parent
7d10451bc8
commit
8f306f8e41
5 changed files with 50 additions and 16 deletions
|
@ -1634,6 +1634,13 @@ def completion(
|
|||
or "https://api.cohere.ai/v1/generate"
|
||||
)
|
||||
|
||||
headers = headers or litellm.headers or {}
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1644,6 +1651,7 @@ def completion(
|
|||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
headers=headers,
|
||||
api_key=cohere_key,
|
||||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
|
@ -1674,6 +1682,13 @@ def completion(
|
|||
or "https://api.cohere.ai/v1/chat"
|
||||
)
|
||||
|
||||
headers = headers or litellm.headers or {}
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere_chat.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1682,6 +1697,7 @@ def completion(
|
|||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key,
|
||||
|
@ -3158,6 +3174,7 @@ def embedding(
|
|||
encoding_format = kwargs.get("encoding_format", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.get("aembedding", None)
|
||||
extra_headers = kwargs.get("extra_headers", None)
|
||||
### CUSTOM MODEL COST ###
|
||||
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||
|
@ -3229,6 +3246,7 @@ def embedding(
|
|||
"model_config",
|
||||
"cooldown_time",
|
||||
"tags",
|
||||
"extra_headers",
|
||||
]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
|
@ -3292,7 +3310,7 @@ def embedding(
|
|||
"cooldown_time": cooldown_time,
|
||||
},
|
||||
)
|
||||
if azure == True or custom_llm_provider == "azure":
|
||||
if azure is True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
||||
|
@ -3398,12 +3416,18 @@ def embedding(
|
|||
or get_secret("CO_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
|
||||
if extra_headers is not None and isinstance(extra_headers, dict):
|
||||
headers = extra_headers
|
||||
else:
|
||||
headers = {}
|
||||
response = cohere.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key, # type: ignore
|
||||
headers=headers,
|
||||
logging_obj=logging,
|
||||
model_response=EmbeddingResponse(),
|
||||
aembedding=aembedding,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue