refactor setting keys

ishaan-jaff 2023-08-17 15:14:25 -07:00
parent 04c55b8f28
commit 3f5e47e3ce
2 changed files with 15 additions and 32 deletions

View file

@@ -66,12 +66,7 @@ def completion(
 openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
 openai.api_version = litellm.api_version if litellm.api_version is not None else get_secret("AZURE_API_VERSION")
 # set key
-if api_key:
-    openai.api_key = api_key
-elif litellm.azure_key:
-    openai.api_key = litellm.azure_key
-else:
-    openai.api_key = get_secret("AZURE_API_KEY")
+openai.api_key = api_key or litellm.azure_key or get_secret("AZURE_API_KEY")
 ## LOGGING
 logging(model=model, input=messages, additional_args=optional_params, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
 ## COMPLETION CALL
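
This hunk and the two that follow replace the same if/elif/else pattern with a single `or` chain, keeping the same precedence: an explicitly passed api_key wins, then the module-level litellm setting, then the secret or environment lookup. Because `or` skips any falsy value, an empty string falls through just like None did under the old truthiness checks. A minimal standalone sketch of that pattern; the names resolve_azure_key, azure_key, and the stub get_secret are illustrative, not code from this commit:

    import os

    azure_key = None  # stands in for litellm.azure_key

    def get_secret(name):
        # stand-in for litellm's get_secret(); here it only reads the environment
        return os.environ.get(name)

    def resolve_azure_key(api_key=None):
        # `or` falls through on any falsy value (None or ""), matching the old
        # `if api_key: ... elif ...: ... else: ...` chain it replaces
        return api_key or azure_key or get_secret("AZURE_API_KEY")

    os.environ["AZURE_API_KEY"] = "key-from-env"
    assert resolve_azure_key("key-from-arg") == "key-from-arg"  # argument wins
    assert resolve_azure_key() == "key-from-env"                # falls back to env
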
@@ -96,12 +91,9 @@ def completion(
 openai.api_version = None
 if litellm.organization:
     openai.organization = litellm.organization
-if api_key:
-    openai.api_key = api_key
-elif litellm.openai_key:
-    openai.api_key = litellm.openai_key
-else:
-    openai.api_key = get_secret("OPENAI_API_KEY")
+# set API KEY
+openai.api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
 ## LOGGING
 logging(model=model, input=messages, additional_args=args, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
 ## COMPLETION CALL
@@ -122,12 +114,7 @@ def completion(
 openai.api_type = "openai"
 openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
 openai.api_version = None
-if api_key:
-    openai.api_key = api_key
-elif litellm.openai_key:
-    openai.api_key = litellm.openai_key
-else:
-    openai.api_key = get_secret("OPENAI_API_KEY")
+openai.api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
 if litellm.organization:
     openai.organization = litellm.organization
 prompt = " ".join([message["content"] for message in messages])
@@ -158,15 +145,13 @@ def completion(
 # import replicate/if it fails then pip install replicate
 install_and_import("replicate")
 import replicate
-# replicate defaults to os.environ.get("REPLICATE_API_TOKEN")
-# checking in case user set it to REPLICATE_API_KEY instead
-if not get_secret("REPLICATE_API_TOKEN") and get_secret("REPLICATE_API_KEY"):
-    replicate_api_token = get_secret("REPLICATE_API_KEY")
-    os.environ["REPLICATE_API_TOKEN"] = replicate_api_token
-elif api_key:
-    os.environ["REPLICATE_API_TOKEN"] = api_key
-elif litellm.replicate_key:
-    os.environ["REPLICATE_API_TOKEN"] = litellm.replicate_key
+# Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
+replicate_key = os.environ.get("REPLICATE_API_TOKEN")
+if replicate_key == None:
+    # user did not set REPLICATE_API_TOKEN in .env
+    replicate_key = get_secret("REPLICATE_API_KEY") or get_secret("REPLICATE_API_TOKEN") or api_key or litellm.replicate_key
+# set replicate key
+os.environ["REPLICATE_API_TOKEN"] = replicate_key
 prompt = " ".join([message["content"] for message in messages])
 input = {"prompt": prompt}
 if "max_tokens" in optional_params:
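
The new Replicate block reads an already exported REPLICATE_API_TOKEN first and only resolves the fallbacks (the REPLICATE_API_KEY variant, the explicit api_key argument, litellm.replicate_key) when it is unset, then writes the result back into os.environ for the replicate client to pick up. A rough sketch of that flow; the stub get_secret and the None guard before the environment write are illustrative additions, since os.environ values must be strings:

    import os

    replicate_module_key = None  # stands in for litellm.replicate_key

    def get_secret(name):
        return os.environ.get(name)  # stand-in for litellm's get_secret()

    def resolve_replicate_token(api_key=None):
        token = os.environ.get("REPLICATE_API_TOKEN")
        if token is None:
            # not exported yet: accept the alternate env var name, then the
            # explicit argument, then the module-level setting
            token = (
                get_secret("REPLICATE_API_KEY")
                or get_secret("REPLICATE_API_TOKEN")
                or api_key
                or replicate_module_key
            )
        if token is not None:  # os.environ only accepts strings
            os.environ["REPLICATE_API_TOKEN"] = token
        return token
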
@@ -304,7 +289,7 @@ def completion(
     response = model_response
 elif custom_llm_provider == "together_ai":
     import requests
-    TOGETHER_AI_TOKEN = get_secret("TOGETHER_AI_TOKEN")
+    TOGETHER_AI_TOKEN = get_secret("TOGETHER_AI_TOKEN") or get_secret("TOGETHERAI_API_KEY")
     headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
     endpoint = 'https://api.together.xyz/inference'
     prompt = " ".join([message["content"] for message in messages]) # TODO: Add chat support for together AI
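
The Together AI change accepts either TOGETHER_AI_TOKEN or TOGETHERAI_API_KEY. A hypothetical helper (not part of the commit) that generalizes this first-match lookup over several environment variable names:

    import os

    def first_secret(*names):
        # return the first environment variable that is set and non-empty,
        # mirroring the `or` chain used in the hunk above
        for name in names:
            value = os.environ.get(name)
            if value:
                return value
        return None

    # usage matching the hunk above
    TOGETHER_AI_TOKEN = first_secret("TOGETHER_AI_TOKEN", "TOGETHERAI_API_KEY")
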

View file

@@ -833,10 +833,8 @@ def get_secret(secret_name):
 # TODO: check which secret manager is being used
 # currently only supports Infisical
 secret = litellm.secret_manager_client.get_secret(secret_name).secret_value
-if secret != None: # failsafe when secret manager fails
-    # if secret manager fails default to using .env variables
-    os.environ[secret_name] = secret # set to env to be safe
-    return secret
+if secret != None:
+    return secret # if secret found in secret manager return it
 elif litellm.api_key != None: # if users use litellm default key
     return litellm.api_key
 else:
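
After this simplification, get_secret checks the configured secret manager first and returns the value as soon as it is found (no longer mirroring it into os.environ), then falls back to litellm.api_key; the final else branch is cut off by the hunk and presumably reads the environment. A self-contained sketch of that lookup order, with a fake client standing in for litellm.secret_manager_client and the environment fallback assumed rather than taken from the diff:

    import os

    class FakeSecret:
        def __init__(self, value):
            self.secret_value = value

    class FakeSecretManager:
        """Stand-in for litellm.secret_manager_client (e.g. an Infisical client)."""
        def __init__(self, store):
            self._store = store
        def get_secret(self, name):
            return FakeSecret(self._store.get(name))

    def get_secret_sketch(secret_name, secret_manager_client=None, default_api_key=None):
        # 1) a configured secret manager wins
        if secret_manager_client is not None:
            secret = secret_manager_client.get_secret(secret_name).secret_value
            if secret is not None:
                return secret
        # 2) fall back to the globally configured key (litellm.api_key in the commit)
        if default_api_key is not None:
            return default_api_key
        # 3) finally read the environment (assumed for the truncated else branch)
        return os.environ.get(secret_name)

    client = FakeSecretManager({"OPENAI_API_KEY": "key-from-manager"})
    assert get_secret_sketch("OPENAI_API_KEY", client) == "key-from-manager"
    assert get_secret_sketch("OPENAI_API_KEY", None, "key-from-litellm") == "key-from-litellm"
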