with infisical managing keys

This commit is contained in:
ishaan-jaff 2023-08-05 12:52:11 -07:00
parent 7575d7ea47
commit 2ecd132a94
5 changed files with 69 additions and 14 deletions

View file

@ -26,6 +26,8 @@ MAX_TOKENS = {
####### PROXY PARAMS ################### configurable params if you use proxy models like Helicone
# Optional override for the API base URL when routing through a proxy.
api_base = None
# Optional extra HTTP headers to send with proxied requests.
headers = None
####### Secret Manager #####################
# User-supplied secret manager client (currently only Infisical is supported
# by `get_secret`); when None, secrets are read from environment variables.
secret_manager_client = None
####### COMPLETION MODELS ###################
open_ai_chat_completion_models = [
"gpt-4",

View file

@ -13,6 +13,7 @@ from tenacity import (
stop_after_attempt,
wait_random_exponential,
) # for exponential backoff
from litellm.utils import get_secret
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
@ -65,14 +66,14 @@ def completion(
if azure == True:
# azure configs
openai.api_type = "azure"
openai.api_base = litellm.api_base if litellm.api_base is not None else os.environ.get("AZURE_API_BASE")
openai.api_version = os.environ.get("AZURE_API_VERSION")
openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
openai.api_version = get_secret("AZURE_API_VERSION")
if api_key:
openai.api_key = api_key
elif litellm.azure_key:
openai.api_key = litellm.azure_key
else:
openai.api_key = os.environ.get("AZURE_API_KEY")
openai.api_key = get_secret("AZURE_API_KEY")
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
## COMPLETION CALL
@ -98,7 +99,7 @@ def completion(
elif litellm.openai_key:
openai.api_key = litellm.openai_key
else:
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.api_key = get_secret("OPENAI_API_KEY")
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
## COMPLETION CALL
@ -124,7 +125,7 @@ def completion(
elif litellm.openai_key:
openai.api_key = litellm.openai_key
else:
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.api_key = get_secret("OPENAI_API_KEY")
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
@ -152,8 +153,8 @@ def completion(
elif "replicate" in model:
# replicate defaults to os.environ.get("REPLICATE_API_TOKEN")
# checking in case user set it to REPLICATE_API_KEY instead
if not os.environ.get("REPLICATE_API_TOKEN") and os.environ.get("REPLICATE_API_KEY"):
replicate_api_token = os.environ.get("REPLICATE_API_KEY")
if not get_secret("REPLICATE_API_TOKEN") and get_secret("REPLICATE_API_KEY"):
replicate_api_token = get_secret("REPLICATE_API_KEY")
os.environ["REPLICATE_API_TOKEN"] = replicate_api_token
elif api_key:
os.environ["REPLICATE_API_TOKEN"] = api_key
@ -240,7 +241,7 @@ def completion(
elif litellm.cohere_key:
cohere_key = litellm.cohere_key
else:
cohere_key = os.environ.get("COHERE_API_KEY")
cohere_key = get_secret("COHERE_API_KEY")
co = cohere.Client(cohere_key)
prompt = " ".join([message["content"] for message in messages])
## LOGGING
@ -286,9 +287,9 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
if azure == True:
# azure configs
openai.api_type = "azure"
openai.api_base = os.environ.get("AZURE_API_BASE")
openai.api_version = os.environ.get("AZURE_API_VERSION")
openai.api_key = os.environ.get("AZURE_API_KEY")
openai.api_base = get_secret("AZURE_API_BASE")
openai.api_version = get_secret("AZURE_API_VERSION")
openai.api_key = get_secret("AZURE_API_KEY")
## LOGGING
logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
## EMBEDDING CALL
@ -298,7 +299,7 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
openai.api_type = "openai"
openai.api_base = "https://api.openai.com/v1"
openai.api_version = None
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.api_key = get_secret("OPENAI_API_KEY")
## LOGGING
logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
## EMBEDDING CALL

View file

@ -0,0 +1,34 @@
#### What this tests ####
# This tests the `completion` endpoint with API keys resolved through a secret manager (Infisical) instead of environment variables
# Requirements: set INFISICAL_TOKEN in the environment, and store the provider API keys (e.g. OPENAI_API_KEY) in Infisical
import sys, os
import traceback
sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
import litellm
from litellm import embedding, completion
from infisical import InfisicalClient
import pytest
# Wire litellm to Infisical so `get_secret` resolves provider keys from the
# secret manager instead of plain environment variables.
# NOTE(review): raises KeyError if INFISICAL_TOKEN is unset — intentional, the
# test run cannot proceed without it.
infisical_token = os.environ["INFISICAL_TOKEN"]
litellm.secret_manager_client = InfisicalClient(token=infisical_token)
# Shared single-turn chat payload used by all tests below.
user_message = "Hello, whats the weather in San Francisco??"
messages = [{ "content": user_message,"role": "user"}]
def test_completion_openai():
    """Smoke-test `completion` for gpt-3.5-turbo with secrets resolved via Infisical.

    Any exception raised by the call is converted into a test failure.
    """
    try:
        result = completion(model="gpt-3.5-turbo", messages=messages)
        # Add any assertions here to check the response
        print(result)
    except Exception as err:
        pytest.fail(f"Error occurred: {err}")
def test_completion_openai_with_optional_params():
    """Smoke-test `completion` for gpt-3.5-turbo with sampling params and a user tag.

    Any exception raised by the call is converted into a test failure.
    """
    try:
        result = completion(model="gpt-3.5-turbo", messages=messages, temperature=0.5, top_p=0.1, user="ishaan_dev@berri.ai")
        # Add any assertions here to check the response
        print(result)
    except Exception as err:
        pytest.fail(f"Error occurred: {err}")

View file

@ -403,4 +403,21 @@ def litellm_telemetry(data):
response.raise_for_status() # Raise an exception for HTTP errors
except requests.exceptions.RequestException as e:
# Handle any errors in the request
pass
pass
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def get_secret(secret_name):
    """Resolve a secret value by name.

    If the user has configured `litellm.secret_manager_client`, the secret is
    looked up there first (currently only Infisical is supported); a value
    found this way is also mirrored into `os.environ` so downstream libraries
    that read env vars directly still work. When no client is configured, the
    client errors out, or it has no value for `secret_name`, fall back to
    `os.environ.get(secret_name)` (which returns None when unset).

    Args:
        secret_name: the environment-variable-style key, e.g. "OPENAI_API_KEY".

    Returns:
        The secret string, or None if it cannot be found anywhere.
    """
    if litellm.secret_manager_client is not None:
        # TODO: check which secret manager is being used
        # currently only supports Infisical
        try:
            secret = litellm.secret_manager_client.get_secret(secret_name).secret_value
        except Exception:
            # if the secret manager fails, default to using .env variables
            secret = None
        if secret is not None:
            os.environ[secret_name] = secret  # set to env to be safe
            return secret
    return os.environ.get(secret_name)

View file

@ -6,4 +6,5 @@ pytest
python-dotenv
openai[datalib]
tenacity
tiktoken
tiktoken
infisical