From 2ecd132a945142a32a6c110abe9b4694fc3a08a0 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Sat, 5 Aug 2023 12:52:11 -0700 Subject: [PATCH] with infisical managing keys --- litellm/__init__.py | 2 ++ litellm/main.py | 25 +++++++++++++------------ litellm/tests/test_secrets.py | 34 ++++++++++++++++++++++++++++++++++ litellm/utils.py | 19 ++++++++++++++++++- requirements.txt | 3 ++- 5 files changed, 69 insertions(+), 14 deletions(-) create mode 100644 litellm/tests/test_secrets.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 57ea54918..2d8bbff11 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -26,6 +26,8 @@ MAX_TOKENS = { ####### PROXY PARAMS ################### configurable params if you use proxy models like Helicone api_base = None headers = None +####### Secret Manager ##################### +secret_manager_client = None ####### COMPLETION MODELS ################### open_ai_chat_completion_models = [ "gpt-4", diff --git a/litellm/main.py b/litellm/main.py index a51498140..114972a8e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -13,6 +13,7 @@ from tenacity import ( stop_after_attempt, wait_random_exponential, ) # for exponential backoff +from litellm.utils import get_secret ####### ENVIRONMENT VARIABLES ################### dotenv.load_dotenv() # Loading env variables using dotenv @@ -65,14 +66,14 @@ def completion( if azure == True: # azure configs openai.api_type = "azure" - openai.api_base = litellm.api_base if litellm.api_base is not None else os.environ.get("AZURE_API_BASE") - openai.api_version = os.environ.get("AZURE_API_VERSION") + openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE") + openai.api_version = get_secret("AZURE_API_VERSION") if api_key: openai.api_key = api_key elif litellm.azure_key: openai.api_key = litellm.azure_key else: - openai.api_key = os.environ.get("AZURE_API_KEY") + openai.api_key = get_secret("AZURE_API_KEY") ## LOGGING logging(model=model, 
input=messages, azure=azure, logger_fn=logger_fn) ## COMPLETION CALL @@ -98,7 +99,7 @@ def completion( elif litellm.openai_key: openai.api_key = litellm.openai_key else: - openai.api_key = os.environ.get("OPENAI_API_KEY") + openai.api_key = get_secret("OPENAI_API_KEY") ## LOGGING logging(model=model, input=messages, azure=azure, logger_fn=logger_fn) ## COMPLETION CALL @@ -124,7 +125,7 @@ def completion( elif litellm.openai_key: openai.api_key = litellm.openai_key else: - openai.api_key = os.environ.get("OPENAI_API_KEY") + openai.api_key = get_secret("OPENAI_API_KEY") prompt = " ".join([message["content"] for message in messages]) ## LOGGING logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn) @@ -152,8 +153,8 @@ def completion( elif "replicate" in model: # replicate defaults to os.environ.get("REPLICATE_API_TOKEN") # checking in case user set it to REPLICATE_API_KEY instead - if not os.environ.get("REPLICATE_API_TOKEN") and os.environ.get("REPLICATE_API_KEY"): - replicate_api_token = os.environ.get("REPLICATE_API_KEY") + if not get_secret("REPLICATE_API_TOKEN") and get_secret("REPLICATE_API_KEY"): + replicate_api_token = get_secret("REPLICATE_API_KEY") os.environ["REPLICATE_API_TOKEN"] = replicate_api_token elif api_key: os.environ["REPLICATE_API_TOKEN"] = api_key @@ -240,7 +241,7 @@ def completion( elif litellm.cohere_key: cohere_key = litellm.cohere_key else: - cohere_key = os.environ.get("COHERE_API_KEY") + cohere_key = get_secret("COHERE_API_KEY") co = cohere.Client(cohere_key) prompt = " ".join([message["content"] for message in messages]) ## LOGGING @@ -286,9 +287,9 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None): if azure == True: # azure configs openai.api_type = "azure" - openai.api_base = os.environ.get("AZURE_API_BASE") - openai.api_version = os.environ.get("AZURE_API_VERSION") - openai.api_key = os.environ.get("AZURE_API_KEY") + openai.api_base = get_secret("AZURE_API_BASE") + openai.api_version = 
get_secret("AZURE_API_VERSION") + openai.api_key = get_secret("AZURE_API_KEY") ## LOGGING logging(model=model, input=input, azure=azure, logger_fn=logger_fn) ## EMBEDDING CALL @@ -298,7 +299,7 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None): openai.api_type = "openai" openai.api_base = "https://api.openai.com/v1" openai.api_version = None - openai.api_key = os.environ.get("OPENAI_API_KEY") + openai.api_key = get_secret("OPENAI_API_KEY") ## LOGGING logging(model=model, input=input, azure=azure, logger_fn=logger_fn) ## EMBEDDING CALL diff --git a/litellm/tests/test_secrets.py b/litellm/tests/test_secrets.py new file mode 100644 index 000000000..474d6bde0 --- /dev/null +++ b/litellm/tests/test_secrets.py @@ -0,0 +1,34 @@ +#### What this tests #### +# This tests reading API keys through a secret manager (Infisical) via litellm.secret_manager_client for the `completion` endpoint +# Requirements: Set INFISICAL_TOKEN in your environment; the API keys used below must exist as secrets in the Infisical project + +import sys, os +import traceback +sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path +import litellm +from litellm import embedding, completion +from infisical import InfisicalClient +import pytest + +infisical_token = os.environ["INFISICAL_TOKEN"] + +litellm.secret_manager_client = InfisicalClient(token=infisical_token) + +user_message = "Hello, whats the weather in San Francisco??" 
+messages = [{ "content": user_message,"role": "user"}] + +def test_completion_openai(): + try: + response = completion(model="gpt-3.5-turbo", messages=messages) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + +def test_completion_openai_with_optional_params(): + try: + response = completion(model="gpt-3.5-turbo", messages=messages, temperature=0.5, top_p=0.1, user="ishaan_dev@berri.ai") + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/litellm/utils.py b/litellm/utils.py index 8e62d3470..4c9fe5463 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -403,4 +403,21 @@ def litellm_telemetry(data): response.raise_for_status() # Raise an exception for HTTP errors except requests.exceptions.RequestException as e: # Handle any errors in the request - pass \ No newline at end of file + pass + +######### Secret Manager ############################ +# checks if user has passed in a secret manager client +# if passed in then checks the secret there +def get_secret(secret_name): + if litellm.secret_manager_client != None: + # TODO: check which secret manager is being used + # currently only supports Infisical + secret = litellm.secret_manager_client.get_secret(secret_name).secret_value + if secret != None: + # secret found in the secret manager; also export it to os.environ for code that reads env vars directly + os.environ[secret_name] = secret # set to env to be safe + return secret + else: + return os.environ.get(secret_name) + else: + return os.environ.get(secret_name) diff --git a/requirements.txt b/requirements.txt index eca980d36..836f656e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ pytest python-dotenv openai[datalib] tenacity -tiktoken \ No newline at end of file +tiktoken +infisical \ No newline at end of file