From 583a3b330d566d8f6c7a8c079f0cfecdcc6978ed Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 14 Aug 2024 13:41:04 -0700 Subject: [PATCH] fix(utils.py): support calling openai models via `azure_ai/` --- litellm/main.py | 8 ++++++-- litellm/proxy/_new_secret_config.yaml | 10 +++++----- litellm/tests/test_completion.py | 25 +++++++++++++++++++++++++ litellm/utils.py | 20 +++++++++++++++++++- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index d7a3ca996d..7be4798574 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -4898,7 +4898,6 @@ async def ahealth_check( verbose_logger.error( "litellm.ahealth_check(): Exception occured - {}".format(str(e)) ) - verbose_logger.debug(traceback.format_exc()) stack_trace = traceback.format_exc() if isinstance(stack_trace, str): stack_trace = stack_trace[:1000] @@ -4907,7 +4906,12 @@ async def ahealth_check( "error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models" } - error_to_return = str(e) + " stack trace: " + stack_trace + error_to_return = ( + str(e) + + "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models" + + "\nstack trace: " + + stack_trace + ) return {"error": error_to_return} diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 87a561e318..41b2a66c01 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,7 +1,7 @@ model_list: - - model_name: "*" + - model_name: azure-embedding-model litellm_params: - model: "*" - -litellm_settings: - success_callback: ["langsmith"] \ No newline at end of file + model: azure/azure-embedding-model + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + api_version: "2023-07-01-preview" diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 4ea9ee3b0f..033b4431fa 100644 --- 
a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -190,6 +190,31 @@ def test_completion_azure_command_r(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize( + "api_base", + [ + "https://litellm8397336933.openai.azure.com", + "https://litellm8397336933.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2023-03-15-preview", + ], +) +def test_completion_azure_ai_gpt_4o(api_base): + try: + litellm.set_verbose = True + + response = completion( + model="azure_ai/gpt-4o", + api_base=api_base, + api_key=os.getenv("AZURE_AI_OPENAI_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + ) + + print(response) + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_completion_databricks(sync_mode): diff --git a/litellm/utils.py b/litellm/utils.py index 49528d0f77..4c5fc6fd48 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4479,7 +4479,22 @@ def _is_non_openai_azure_model(model: str) -> bool: or f"mistral/{model_name}" in litellm.mistral_chat_models ): return True - except: + except Exception: + return False + return False + + +def _is_azure_openai_model(model: str) -> bool: + try: + if "/" in model: + model = model.split("/", 1)[1] + if ( + model in litellm.open_ai_chat_completion_models + or model in litellm.open_ai_text_completion_models + or model in litellm.open_ai_embedding_models + ): + return True + except Exception: return False return False @@ -4613,6 +4628,9 @@ def get_llm_provider( elif custom_llm_provider == "azure_ai": api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY") + + if _is_azure_openai_model(model=model): + custom_llm_provider = "azure" elif custom_llm_provider == "github": api_base = api_base or get_secret("GITHUB_API_BASE") or
"https://models.inference.ai.azure.com" # type: ignore dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")