fix(utils.py): add new 'azure_ai/' route

supports azure's openai compatible api endpoint
This commit is contained in:
Krrish Dholakia 2024-06-11 14:06:56 -07:00
parent 61e5e162de
commit 4a27a50f9b
4 changed files with 33 additions and 4 deletions

View file

@ -1,8 +1,8 @@
repos:
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
# - repo: https://github.com/psf/black
# rev: 24.2.0
# hooks:
# - id: black
- repo: https://github.com/pycqa/flake8
rev: 7.0.0 # The version of flake8 to use
hooks:

View file

@ -405,6 +405,7 @@ openai_compatible_providers: List = [
"xinference",
"together_ai",
"fireworks_ai",
"azure_ai",
]
@ -609,6 +610,7 @@ provider_list: List = [
"baseten",
"azure",
"azure_text",
"azure_ai",
"sagemaker",
"bedrock",
"vllm",

View file

@ -114,6 +114,27 @@ def test_null_role_response():
assert response.choices[0].message.role == "assistant"
def test_completion_azure_ai_command_r():
try:
import os
litellm.set_verbose = True
os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
response: litellm.ModelResponse = completion(
model="azure_ai/command-r-plus",
messages=[{"role": "user", "content": "What is the meaning of life?"}],
) # type: ignore
assert "azure_ai" in response.model
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_azure_command_r():
try:
litellm.set_verbose = True

View file

@ -6612,6 +6612,12 @@ def get_llm_provider(
or get_secret("TOGETHERAI_API_KEY")
or get_secret("TOGETHER_AI_TOKEN")
)
elif custom_llm_provider == "azure_ai":
api_base = (
api_base
or get_secret("AZURE_AI_API_BASE") # for Azure AI Mistral
) # type: ignore
dynamic_api_key = get_secret("AZURE_AI_API_KEY")
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)