mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Merge branch 'main' into litellm_allow_using_azure_ad_token_auth
This commit is contained in:
commit
228252b92d
33 changed files with 802 additions and 84 deletions
|
@ -944,6 +944,8 @@ def completion(
|
|||
cooldown_time=cooldown_time,
|
||||
text_completion=kwargs.get("text_completion"),
|
||||
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
|
||||
user_continue_message=kwargs.get("user_continue_message"),
|
||||
|
||||
)
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
@ -1635,6 +1637,13 @@ def completion(
|
|||
or "https://api.cohere.ai/v1/generate"
|
||||
)
|
||||
|
||||
headers = headers or litellm.headers or {}
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1645,6 +1654,7 @@ def completion(
|
|||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
headers=headers,
|
||||
api_key=cohere_key,
|
||||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
|
@ -1675,6 +1685,13 @@ def completion(
|
|||
or "https://api.cohere.ai/v1/chat"
|
||||
)
|
||||
|
||||
headers = headers or litellm.headers or {}
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere_chat.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1683,6 +1700,7 @@ def completion(
|
|||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key,
|
||||
|
@ -2289,7 +2307,7 @@ def completion(
|
|||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
litellm_params=litellm_params, # type: ignore
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
|
@ -3159,6 +3177,7 @@ def embedding(
|
|||
encoding_format = kwargs.get("encoding_format", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.get("aembedding", None)
|
||||
extra_headers = kwargs.get("extra_headers", None)
|
||||
### CUSTOM MODEL COST ###
|
||||
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||
|
@ -3234,6 +3253,7 @@ def embedding(
|
|||
"tenant_id",
|
||||
"client_id",
|
||||
"client_secret",
|
||||
"extra_headers",
|
||||
]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
|
@ -3297,7 +3317,7 @@ def embedding(
|
|||
"cooldown_time": cooldown_time,
|
||||
},
|
||||
)
|
||||
if azure == True or custom_llm_provider == "azure":
|
||||
if azure is True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
||||
|
@ -3403,12 +3423,18 @@ def embedding(
|
|||
or get_secret("CO_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
|
||||
if extra_headers is not None and isinstance(extra_headers, dict):
|
||||
headers = extra_headers
|
||||
else:
|
||||
headers = {}
|
||||
response = cohere.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key, # type: ignore
|
||||
headers=headers,
|
||||
logging_obj=logging,
|
||||
model_response=EmbeddingResponse(),
|
||||
aembedding=aembedding,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue