with Infisical managing keys

ishaan-jaff 2023-08-05 12:52:11 -07:00
parent 7575d7ea47
commit 2ecd132a94
5 changed files with 69 additions and 14 deletions


@@ -13,6 +13,7 @@ from tenacity import (
     stop_after_attempt,
     wait_random_exponential,
 ) # for exponential backoff
+from litellm.utils import get_secret
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
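
The new import pulls in get_secret from litellm.utils, which the rest of the diff uses in place of os.environ.get so that keys can come from a secret manager such as Infisical. As a rough sketch of the behavior the commit relies on (this is an assumption for illustration, not the actual litellm.utils implementation; the .get_secret(name).secret_value shape is an assumed SDK interface):

import os

def get_secret(secret_name, secret_manager_client=None):
    # Sketch: if an Infisical-style client is configured, ask it first.
    if secret_manager_client is not None:
        secret = secret_manager_client.get_secret(secret_name)
        if secret is not None and getattr(secret, "secret_value", None) is not None:
            return secret.secret_value
    # Otherwise fall back to the environment, matching the old os.environ.get calls.
    return os.environ.get(secret_name)
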
@@ -65,14 +66,14 @@ def completion(
     if azure == True:
         # azure configs
         openai.api_type = "azure"
-        openai.api_base = litellm.api_base if litellm.api_base is not None else os.environ.get("AZURE_API_BASE")
-        openai.api_version = os.environ.get("AZURE_API_VERSION")
+        openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
+        openai.api_version = get_secret("AZURE_API_VERSION")
         if api_key:
             openai.api_key = api_key
         elif litellm.azure_key:
             openai.api_key = litellm.azure_key
         else:
-            openai.api_key = os.environ.get("AZURE_API_KEY")
+            openai.api_key = get_secret("AZURE_API_KEY")
         ## LOGGING
         logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
         ## COMPLETION CALL
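
In the Azure branch the substitution only touches the last fallback in the key-resolution chain: an explicit api_key argument still wins, then litellm.azure_key, and only then the secret manager. A small illustration of that precedence (resolve_azure_key is a hypothetical helper invented for this example):

import litellm
from litellm.utils import get_secret

def resolve_azure_key(api_key=None):
    # Order mirrors the diff: explicit argument, then litellm.azure_key,
    # then the secret manager / environment via get_secret().
    if api_key:
        return api_key
    if litellm.azure_key:
        return litellm.azure_key
    return get_secret("AZURE_API_KEY")
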
@@ -98,7 +99,7 @@ def completion(
         elif litellm.openai_key:
             openai.api_key = litellm.openai_key
         else:
-            openai.api_key = os.environ.get("OPENAI_API_KEY")
+            openai.api_key = get_secret("OPENAI_API_KEY")
         ## LOGGING
         logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
         ## COMPLETION CALL
@@ -124,7 +125,7 @@ def completion(
         elif litellm.openai_key:
             openai.api_key = litellm.openai_key
         else:
-            openai.api_key = os.environ.get("OPENAI_API_KEY")
+            openai.api_key = get_secret("OPENAI_API_KEY")
         prompt = " ".join([message["content"] for message in messages])
         ## LOGGING
         logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
@@ -152,8 +153,8 @@ def completion(
     elif "replicate" in model:
         # replicate defaults to os.environ.get("REPLICATE_API_TOKEN")
         # checking in case user set it to REPLICATE_API_KEY instead
-        if not os.environ.get("REPLICATE_API_TOKEN") and os.environ.get("REPLICATE_API_KEY"):
-            replicate_api_token = os.environ.get("REPLICATE_API_KEY")
+        if not get_secret("REPLICATE_API_TOKEN") and get_secret("REPLICATE_API_KEY"):
+            replicate_api_token = get_secret("REPLICATE_API_KEY")
             os.environ["REPLICATE_API_TOKEN"] = replicate_api_token
         elif api_key:
            os.environ["REPLICATE_API_TOKEN"] = api_key
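
The Replicate branch still writes the resolved token back into os.environ because the replicate SDK reads REPLICATE_API_TOKEN itself; the swap to get_secret only changes where the value is looked up. For the truthiness check to keep the old behavior, get_secret has to return a falsy value (e.g. None) when the key is not set anywhere. A standalone sketch of the same normalization under that assumption (normalize_replicate_token is a name invented for this example):

import os

def normalize_replicate_token(get_secret):
    # Assumes get_secret returns None/falsy when the key is missing everywhere.
    if not get_secret("REPLICATE_API_TOKEN") and get_secret("REPLICATE_API_KEY"):
        # Key stored under the wrong name; copy it to the variable replicate reads.
        os.environ["REPLICATE_API_TOKEN"] = get_secret("REPLICATE_API_KEY")
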
@@ -240,7 +241,7 @@ def completion(
         elif litellm.cohere_key:
             cohere_key = litellm.cohere_key
         else:
-            cohere_key = os.environ.get("COHERE_API_KEY")
+            cohere_key = get_secret("COHERE_API_KEY")
         co = cohere.Client(cohere_key)
         prompt = " ".join([message["content"] for message in messages])
         ## LOGGING
@@ -286,9 +287,9 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
     if azure == True:
         # azure configs
         openai.api_type = "azure"
-        openai.api_base = os.environ.get("AZURE_API_BASE")
-        openai.api_version = os.environ.get("AZURE_API_VERSION")
-        openai.api_key = os.environ.get("AZURE_API_KEY")
+        openai.api_base = get_secret("AZURE_API_BASE")
+        openai.api_version = get_secret("AZURE_API_VERSION")
+        openai.api_key = get_secret("AZURE_API_KEY")
         ## LOGGING
         logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
         ## EMBEDDING CALL
@@ -298,7 +299,7 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
         openai.api_type = "openai"
         openai.api_base = "https://api.openai.com/v1"
         openai.api_version = None
-        openai.api_key = os.environ.get("OPENAI_API_KEY")
+        openai.api_key = get_secret("OPENAI_API_KEY")
         ## LOGGING
         logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
         ## EMBEDDING CALL
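
Taken together, the diff keeps the completion() and embedding() call signatures unchanged while letting the keys come from Infisical. A hypothetical end-to-end usage, assuming an Infisical Python client and a litellm.secret_manager_client hook (neither name is confirmed by this diff):

import litellm
from infisical import InfisicalClient  # assumed package/class name
from litellm import completion

# Hand litellm an Infisical client so get_secret() can resolve keys from it (assumed hook).
litellm.secret_manager_client = InfisicalClient(token="<your-infisical-service-token>")

# OPENAI_API_KEY is now looked up through get_secret() inside completion().
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, who are you?"}],
)
print(response)
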