fix(utils.py): add new 'azure_ai/' route

supports azure's openai compatible api endpoint
This commit is contained in:
Krrish Dholakia 2024-06-11 14:06:56 -07:00
parent 318abde095
commit 88e567af2c
4 changed files with 33 additions and 4 deletions

View file

@ -1,8 +1,8 @@
repos: repos:
- repo: https://github.com/psf/black # - repo: https://github.com/psf/black
rev: 24.2.0 # rev: 24.2.0
hooks: # hooks:
- id: black # - id: black
- repo: https://github.com/pycqa/flake8 - repo: https://github.com/pycqa/flake8
rev: 7.0.0 # The version of flake8 to use rev: 7.0.0 # The version of flake8 to use
hooks: hooks:

View file

@ -405,6 +405,7 @@ openai_compatible_providers: List = [
"xinference", "xinference",
"together_ai", "together_ai",
"fireworks_ai", "fireworks_ai",
"azure_ai",
] ]
@ -609,6 +610,7 @@ provider_list: List = [
"baseten", "baseten",
"azure", "azure",
"azure_text", "azure_text",
"azure_ai",
"sagemaker", "sagemaker",
"bedrock", "bedrock",
"vllm", "vllm",

View file

@ -114,6 +114,27 @@ def test_null_role_response():
assert response.choices[0].message.role == "assistant" assert response.choices[0].message.role == "assistant"
def test_completion_azure_ai_command_r():
try:
import os
litellm.set_verbose = True
os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
response: litellm.ModelResponse = completion(
model="azure_ai/command-r-plus",
messages=[{"role": "user", "content": "What is the meaning of life?"}],
) # type: ignore
assert "azure_ai" in response.model
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_azure_command_r(): def test_completion_azure_command_r():
try: try:
litellm.set_verbose = True litellm.set_verbose = True

View file

@ -6612,6 +6612,12 @@ def get_llm_provider(
or get_secret("TOGETHERAI_API_KEY") or get_secret("TOGETHERAI_API_KEY")
or get_secret("TOGETHER_AI_TOKEN") or get_secret("TOGETHER_AI_TOKEN")
) )
elif custom_llm_provider == "azure_ai":
api_base = (
api_base
or get_secret("AZURE_AI_API_BASE") # for Azure AI Mistral
) # type: ignore
dynamic_api_key = get_secret("AZURE_AI_API_KEY")
if api_base is not None and not isinstance(api_base, str): if api_base is not None and not isinstance(api_base, str):
raise Exception( raise Exception(
"api base needs to be a string. api_base={}".format(api_base) "api base needs to be a string. api_base={}".format(api_base)