From 4a27a50f9b52e7938499ec4f653c7f1b9c945d2b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 11 Jun 2024 14:06:56 -0700 Subject: [PATCH] fix(utils.py): add new 'azure_ai/' route supports azure's openai compatible api endpoint --- .pre-commit-config.yaml | 8 ++++---- litellm/__init__.py | 2 ++ litellm/tests/test_completion.py | 21 +++++++++++++++++++++ litellm/utils.py | 6 ++++++ 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc41d85f14..bec679090b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,8 @@ repos: -- repo: https://github.com/psf/black - rev: 24.2.0 - hooks: - - id: black +# - repo: https://github.com/psf/black +# rev: 24.2.0 +# hooks: +# - id: black - repo: https://github.com/pycqa/flake8 rev: 7.0.0 # The version of flake8 to use hooks: diff --git a/litellm/__init__.py b/litellm/__init__.py index e92ae355e2..3f755d10f6 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -405,6 +405,7 @@ openai_compatible_providers: List = [ "xinference", "together_ai", "fireworks_ai", + "azure_ai", ] @@ -609,6 +610,7 @@ provider_list: List = [ "baseten", "azure", "azure_text", + "azure_ai", "sagemaker", "bedrock", "vllm", diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 2428cbf48d..3c3ba564ae 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -114,6 +114,27 @@ def test_null_role_response(): assert response.choices[0].message.role == "assistant" +def test_completion_azure_ai_command_r(): + try: + import os + + litellm.set_verbose = True + + os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "") + os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "") + + response: litellm.ModelResponse = completion( + model="azure_ai/command-r-plus", + messages=[{"role": "user", "content": "What is the meaning of life?"}], + ) # type: ignore + + assert "azure_ai" in response.model + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_completion_azure_command_r(): try: litellm.set_verbose = True diff --git a/litellm/utils.py b/litellm/utils.py index 5e85419dcd..1538cec1d4 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6612,6 +6612,12 @@ def get_llm_provider( or get_secret("TOGETHERAI_API_KEY") or get_secret("TOGETHER_AI_TOKEN") ) + elif custom_llm_provider == "azure_ai": + api_base = ( + api_base + or get_secret("AZURE_AI_API_BASE") # for Azure AI Mistral + ) # type: ignore + dynamic_api_key = get_secret("AZURE_AI_API_KEY") if api_base is not None and not isinstance(api_base, str): raise Exception( "api base needs to be a string. api_base={}".format(api_base)