mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(cohere_chat.py): support passing 'extra_headers'
Fixes https://github.com/BerriAI/litellm/issues/4709
This commit is contained in:
parent
36c021b309
commit
11bfc1dca7
5 changed files with 50 additions and 16 deletions
|
@ -124,12 +124,14 @@ class CohereConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_environment(api_key):
|
def validate_environment(api_key, headers: dict):
|
||||||
headers = {
|
headers.update(
|
||||||
"Request-Source": "unspecified:litellm",
|
{
|
||||||
"accept": "application/json",
|
"Request-Source": "unspecified:litellm",
|
||||||
"content-type": "application/json",
|
"accept": "application/json",
|
||||||
}
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
)
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
return headers
|
return headers
|
||||||
|
@ -144,11 +146,12 @@ def completion(
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
headers: dict,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
completion_url = api_base
|
completion_url = api_base
|
||||||
model = model
|
model = model
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
prompt = " ".join(message["content"] for message in messages)
|
||||||
|
@ -338,13 +341,14 @@ def embedding(
|
||||||
model_response: litellm.EmbeddingResponse,
|
model_response: litellm.EmbeddingResponse,
|
||||||
logging_obj: LiteLLMLoggingObj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
headers: dict,
|
||||||
encoding: Any,
|
encoding: Any,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
aembedding: Optional[bool] = None,
|
aembedding: Optional[bool] = None,
|
||||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
embed_url = "https://api.cohere.ai/v1/embed"
|
embed_url = "https://api.cohere.ai/v1/embed"
|
||||||
model = model
|
model = model
|
||||||
data = {"model": model, "texts": input, **optional_params}
|
data = {"model": model, "texts": input, **optional_params}
|
||||||
|
|
|
@ -116,12 +116,14 @@ class CohereChatConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_environment(api_key):
|
def validate_environment(api_key, headers: dict):
|
||||||
headers = {
|
headers.update(
|
||||||
"Request-Source": "unspecified:litellm",
|
{
|
||||||
"accept": "application/json",
|
"Request-Source": "unspecified:litellm",
|
||||||
"content-type": "application/json",
|
"accept": "application/json",
|
||||||
}
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
)
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
return headers
|
return headers
|
||||||
|
@ -203,13 +205,14 @@ def completion(
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
headers: dict,
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
completion_url = api_base
|
completion_url = api_base
|
||||||
model = model
|
model = model
|
||||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||||
|
|
|
@ -1634,6 +1634,13 @@ def completion(
|
||||||
or "https://api.cohere.ai/v1/generate"
|
or "https://api.cohere.ai/v1/generate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
headers = headers or litellm.headers or {}
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers.update(extra_headers)
|
||||||
|
|
||||||
model_response = cohere.completion(
|
model_response = cohere.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -1644,6 +1651,7 @@ def completion(
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
headers=headers,
|
||||||
api_key=cohere_key,
|
api_key=cohere_key,
|
||||||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||||
)
|
)
|
||||||
|
@ -1674,6 +1682,13 @@ def completion(
|
||||||
or "https://api.cohere.ai/v1/chat"
|
or "https://api.cohere.ai/v1/chat"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
headers = headers or litellm.headers or {}
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers.update(extra_headers)
|
||||||
|
|
||||||
model_response = cohere_chat.completion(
|
model_response = cohere_chat.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -1682,6 +1697,7 @@ def completion(
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
|
headers=headers,
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
api_key=cohere_key,
|
api_key=cohere_key,
|
||||||
|
@ -3158,6 +3174,7 @@ def embedding(
|
||||||
encoding_format = kwargs.get("encoding_format", None)
|
encoding_format = kwargs.get("encoding_format", None)
|
||||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||||
aembedding = kwargs.get("aembedding", None)
|
aembedding = kwargs.get("aembedding", None)
|
||||||
|
extra_headers = kwargs.get("extra_headers", None)
|
||||||
### CUSTOM MODEL COST ###
|
### CUSTOM MODEL COST ###
|
||||||
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||||
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||||
|
@ -3229,6 +3246,7 @@ def embedding(
|
||||||
"model_config",
|
"model_config",
|
||||||
"cooldown_time",
|
"cooldown_time",
|
||||||
"tags",
|
"tags",
|
||||||
|
"extra_headers",
|
||||||
]
|
]
|
||||||
default_params = openai_params + litellm_params
|
default_params = openai_params + litellm_params
|
||||||
non_default_params = {
|
non_default_params = {
|
||||||
|
@ -3292,7 +3310,7 @@ def embedding(
|
||||||
"cooldown_time": cooldown_time,
|
"cooldown_time": cooldown_time,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if azure == True or custom_llm_provider == "azure":
|
if azure is True or custom_llm_provider == "azure":
|
||||||
# azure configs
|
# azure configs
|
||||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||||
|
|
||||||
|
@ -3398,12 +3416,18 @@ def embedding(
|
||||||
or get_secret("CO_API_KEY")
|
or get_secret("CO_API_KEY")
|
||||||
or litellm.api_key
|
or litellm.api_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if extra_headers is not None and isinstance(extra_headers, dict):
|
||||||
|
headers = extra_headers
|
||||||
|
else:
|
||||||
|
headers = {}
|
||||||
response = cohere.embedding(
|
response = cohere.embedding(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
api_key=cohere_key, # type: ignore
|
api_key=cohere_key, # type: ignore
|
||||||
|
headers=headers,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
model_response=EmbeddingResponse(),
|
model_response=EmbeddingResponse(),
|
||||||
aembedding=aembedding,
|
aembedding=aembedding,
|
||||||
|
|
|
@ -3653,6 +3653,7 @@ def test_completion_cohere():
|
||||||
response = completion(
|
response = completion(
|
||||||
model="command-r",
|
model="command-r",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
extra_headers={"Helicone-Property-Locale": "ko"},
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -4219,6 +4219,7 @@ def get_supported_openai_params(
|
||||||
"presence_penalty",
|
"presence_penalty",
|
||||||
"stop",
|
"stop",
|
||||||
"n",
|
"n",
|
||||||
|
"extra_headers",
|
||||||
]
|
]
|
||||||
elif custom_llm_provider == "cohere_chat":
|
elif custom_llm_provider == "cohere_chat":
|
||||||
return [
|
return [
|
||||||
|
@ -4233,6 +4234,7 @@ def get_supported_openai_params(
|
||||||
"tools",
|
"tools",
|
||||||
"tool_choice",
|
"tool_choice",
|
||||||
"seed",
|
"seed",
|
||||||
|
"extra_headers",
|
||||||
]
|
]
|
||||||
elif custom_llm_provider == "maritalk":
|
elif custom_llm_provider == "maritalk":
|
||||||
return [
|
return [
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue