forked from phoenix/litellm-mirror
fix(cohere_chat.py): support passing 'extra_headers'
Fixes https://github.com/BerriAI/litellm/issues/4709
parent 36c021b309
commit 11bfc1dca7

5 changed files with 50 additions and 16 deletions
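
After this change, an OpenAI-style extra_headers argument passed to litellm.completion is forwarded on Cohere requests. A minimal usage sketch, with the model and header taken from the regression test below (the message text is illustrative):

import litellm

response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_headers={"Helicone-Property-Locale": "ko"},
)
print(response)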

litellm/llms/cohere.py

@@ -124,12 +124,14 @@ class CohereConfig:
         }


-def validate_environment(api_key):
-    headers = {
-        "Request-Source": "unspecified:litellm",
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
+def validate_environment(api_key, headers: dict):
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
     return headers
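
A sketch of the new helper behavior: a caller-supplied dict survives the call, with litellm's default headers and the Authorization header merged in (the api_key value is a placeholder):

headers = validate_environment(
    api_key="co-example-key",
    headers={"Helicone-Property-Locale": "ko"},
)
# headers now contains the custom header, the three defaults above,
# and "Authorization": "Bearer co-example-key"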

@@ -144,11 +146,12 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    headers: dict,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
     model = model
     prompt = " ".join(message["content"] for message in messages)
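
Note that headers is now a required parameter here, so every internal caller has to supply a dict; the litellm/main.py changes below do exactly that.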

@@ -338,13 +341,14 @@ def embedding(
     model_response: litellm.EmbeddingResponse,
     logging_obj: LiteLLMLoggingObj,
     optional_params: dict,
+    headers: dict,
     encoding: Any,
     api_key: Optional[str] = None,
     aembedding: Optional[bool] = None,
     timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
     client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     embed_url = "https://api.cohere.ai/v1/embed"
     model = model
     data = {"model": model, "texts": input, **optional_params}
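
As with completion, the embedding handler now takes the caller's headers dict and routes it through validate_environment before the request to the /v1/embed endpoint.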

litellm/llms/cohere_chat.py

@@ -116,12 +116,14 @@ class CohereChatConfig:
         }


-def validate_environment(api_key):
-    headers = {
-        "Request-Source": "unspecified:litellm",
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
+def validate_environment(api_key, headers: dict):
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
     return headers

@@ -203,13 +205,14 @@ def completion(
     model_response: ModelResponse,
     print_verbose: Callable,
     optional_params: dict,
+    headers: dict,
     encoding,
     api_key,
     logging_obj,
     litellm_params=None,
     logger_fn=None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
     model = model
     most_recent_message, chat_history = cohere_messages_pt_v2(
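
With both handlers updated, the remaining work is in the router: litellm/main.py (next file) builds the merged headers dict and passes it into the Cohere calls.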

litellm/main.py

@@ -1634,6 +1634,13 @@ def completion(
             or "https://api.cohere.ai/v1/generate"
         )

+        headers = headers or litellm.headers or {}
+        if headers is None:
+            headers = {}
+
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
         model_response = cohere.completion(
             model=model,
             messages=messages,
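
A standalone sketch of the merge semantics above; values are hypothetical, and the sketch omits the `if headers is None` guard, which is a no-op after the `or {}` fallback:

litellm_headers = {"x-team": "platform"}            # stand-in for litellm.headers
per_call_headers = None                             # no headers kwarg passed
extra_headers = {"Helicone-Property-Locale": "ko"}

headers = per_call_headers or litellm_headers or {}
if extra_headers is not None:
    headers.update(extra_headers)                   # extra_headers wins on key clashes

print(headers)
# {'x-team': 'platform', 'Helicone-Property-Locale': 'ko'}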

@@ -1644,6 +1651,7 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             encoding=encoding,
+            headers=headers,
             api_key=cohere_key,
             logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )

@@ -1674,6 +1682,13 @@ def completion(
             or "https://api.cohere.ai/v1/chat"
         )

+        headers = headers or litellm.headers or {}
+        if headers is None:
+            headers = {}
+
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
         model_response = cohere_chat.completion(
             model=model,
             messages=messages,
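
The chat branch mirrors the text-completion branch above, so requests to both /v1/generate and /v1/chat pick up the merged headers.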

@@ -1682,6 +1697,7 @@ def completion(
             print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
+            headers=headers,
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=cohere_key,

@@ -3158,6 +3174,7 @@ def embedding(
     encoding_format = kwargs.get("encoding_format", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
+    extra_headers = kwargs.get("extra_headers", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)

@@ -3229,6 +3246,7 @@ def embedding(
         "model_config",
         "cooldown_time",
         "tags",
+        "extra_headers",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {

@@ -3292,7 +3310,7 @@ def embedding(
             "cooldown_time": cooldown_time,
         },
     )
-    if azure == True or custom_llm_provider == "azure":
+    if azure is True or custom_llm_provider == "azure":
         # azure configs
         api_type = get_secret("AZURE_API_TYPE") or "azure"


@@ -3398,12 +3416,18 @@ def embedding(
             or get_secret("CO_API_KEY")
             or litellm.api_key
         )
+
+        if extra_headers is not None and isinstance(extra_headers, dict):
+            headers = extra_headers
+        else:
+            headers = {}
         response = cohere.embedding(
             model=model,
             input=input,
             optional_params=optional_params,
             encoding=encoding,
             api_key=cohere_key,  # type: ignore
+            headers=headers,
             logging_obj=logging,
             model_response=EmbeddingResponse(),
             aembedding=aembedding,
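
For embeddings the router forwards extra_headers directly, with no litellm.headers fallback on this path. A hypothetical call (the model string follows litellm's cohere/ prefix convention; input text is illustrative):

import litellm

response = litellm.embedding(
    model="cohere/embed-english-v3.0",
    input=["good morning from litellm"],
    extra_headers={"Helicone-Property-Locale": "ko"},
)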

litellm/tests/test_completion.py

@@ -3653,6 +3653,7 @@ def test_completion_cohere():
         response = completion(
             model="command-r",
             messages=messages,
+            extra_headers={"Helicone-Property-Locale": "ko"},
         )
         print(response)
     except Exception as e:
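
The regression test drives a command-r completion with a Helicone property header, matching the pass-through scenario reported in the linked issue.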

litellm/utils.py

@@ -4219,6 +4219,7 @@ def get_supported_openai_params(
             "presence_penalty",
             "stop",
             "n",
+            "extra_headers",
         ]
     elif custom_llm_provider == "cohere_chat":
         return [

@@ -4233,6 +4234,7 @@ def get_supported_openai_params(
             "tools",
             "tool_choice",
             "seed",
+            "extra_headers",
         ]
     elif custom_llm_provider == "maritalk":
         return [
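
A quick check that the capability is now advertised, assuming the helper is exposed at the package root (litellm re-exports get_supported_openai_params):

from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="command-r", custom_llm_provider="cohere_chat"
)
assert "extra_headers" in params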