Merge branch 'main' into litellm_allow_using_azure_ad_token_auth

This commit is contained in:
Ishaan Jaff 2024-08-22 18:21:24 -07:00 committed by GitHub
commit 228252b92d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 802 additions and 84 deletions

View file

@ -944,6 +944,8 @@ def completion(
cooldown_time=cooldown_time,
text_completion=kwargs.get("text_completion"),
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
)
logging.update_environment_variables(
model=model,
@ -1635,6 +1637,13 @@ def completion(
or "https://api.cohere.ai/v1/generate"
)
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere.completion(
model=model,
messages=messages,
@ -1645,6 +1654,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
headers=headers,
api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
)
@ -1675,6 +1685,13 @@ def completion(
or "https://api.cohere.ai/v1/chat"
)
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere_chat.completion(
model=model,
messages=messages,
@ -1683,6 +1700,7 @@ def completion(
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
@ -2289,7 +2307,7 @@ def completion(
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
@ -3159,6 +3177,7 @@ def embedding(
encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.get("aembedding", None)
extra_headers = kwargs.get("extra_headers", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
@ -3234,6 +3253,7 @@ def embedding(
"tenant_id",
"client_id",
"client_secret",
"extra_headers",
]
default_params = openai_params + litellm_params
non_default_params = {
@ -3297,7 +3317,7 @@ def embedding(
"cooldown_time": cooldown_time,
},
)
if azure == True or custom_llm_provider == "azure":
if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@ -3403,12 +3423,18 @@ def embedding(
or get_secret("CO_API_KEY")
or litellm.api_key
)
if extra_headers is not None and isinstance(extra_headers, dict):
headers = extra_headers
else:
headers = {}
response = cohere.embedding(
model=model,
input=input,
optional_params=optional_params,
encoding=encoding,
api_key=cohere_key, # type: ignore
headers=headers,
logging_obj=logging,
model_response=EmbeddingResponse(),
aembedding=aembedding,