From 4f183dc6a0b208b438e5119c14d31ead27faa9f6 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 23 Nov 2023 13:41:56 -0800
Subject: [PATCH] fix(utils.py): support reading api keys dynamically from the os environment

---
 litellm/main.py                  |  4 ++--
 litellm/tests/test_completion.py | 20 +++++++++++---------
 litellm/utils.py                 |  7 +++++--
 3 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 5b29e62a21..a6bb673b12 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -371,7 +371,7 @@ def completion(
         if deployment_id != None: # azure llms
             model=deployment_id
             custom_llm_provider="azure"
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
+        model, custom_llm_provider, api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
         custom_prompt_dict = {} # type: ignore
         if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
             custom_prompt_dict = {model: {}}
@@ -1709,7 +1709,7 @@ def embedding(
     - exception_type: If an exception occurs during the API call.
     """
     azure = kwargs.get("azure", None)
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
     try:
         response = None
         logging = litellm_logging_obj
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index f4c5a83f62..c3db66499a 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -183,7 +183,7 @@ def test_completion_perplexity_api():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_perplexity_api()
+# test_completion_perplexity_api()
 
 def test_completion_perplexity_api_2():
     try:
@@ -582,20 +582,22 @@ def test_completion_azure():
         response = completion(
             model="azure/chatgpt-v-2",
             messages=messages,
+            api_key="os.environ/AZURE_API_KEY_NA"
         )
+        print(f"response: {response}")
         ## Test azure flag for backwards compatibility
-        response = completion(
-            model="chatgpt-v-2",
-            messages=messages,
-            azure=True,
-            max_tokens=10
-        )
+        # response = completion(
+        #     model="chatgpt-v-2",
+        #     messages=messages,
+        #     azure=True,
+        #     max_tokens=10
+        # )
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_azure()
+test_completion_azure()
 
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header
@@ -833,7 +835,7 @@ def test_completion_replicate_llama2_stream():
         print(f"complete_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_replicate_llama2_stream()
+# test_completion_replicate_llama2_stream()
 
 # commenthing this out since we won't be always testing a custom replicate deployment
 # def test_completion_replicate_deployments():
diff --git a/litellm/utils.py b/litellm/utils.py
index 96cf556be9..e9e649252a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2294,14 +2294,17 @@ def get_optional_params(  # use the openai defaults
             optional_params[k] = passed_params[k]
     return optional_params
 
-def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_base: Optional[str] = None):
+def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_base: Optional[str] = None, api_key: Optional[str] = None):
     try:
         dynamic_api_key = None
         # check if llm provider provided
         if custom_llm_provider:
             return model, custom_llm_provider, dynamic_api_key, api_base
-    
+
+        if api_key and api_key.startswith("os.environ/"):
+            api_key_env_name = api_key.replace("os.environ/", "")
+            dynamic_api_key = os.getenv(api_key_env_name)
         # check if llm provider part of model name
         if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
             custom_llm_provider = model.split("/", 1)[0]
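
Usage sketch (illustrative, not part of the patch): with this change, an api_key value that begins with "os.environ/" is resolved inside get_llm_provider() via os.getenv() at request time, so the key itself never has to be hard-coded at the call site. The environment-variable name and Azure deployment below mirror the updated test_completion_azure() test and are assumptions; substitute your own names.

import os
import litellm

# Illustrative setup only: the variable name mirrors the test above.
# AZURE_API_BASE / AZURE_API_VERSION are assumed to be configured as usual for Azure calls.
os.environ["AZURE_API_KEY_NA"] = "<your-azure-api-key>"

response = litellm.completion(
    model="azure/chatgpt-v-2",              # Azure deployment name (assumed)
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    api_key="os.environ/AZURE_API_KEY_NA",  # resolved with os.getenv() inside get_llm_provider()
)
print(response)

The same "os.environ/<VAR_NAME>" convention should apply to embedding() as well, since both code paths now forward api_key into get_llm_provider().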