fix(utils.py): support calling openai models via azure_ai/

This commit is contained in:
Krrish Dholakia 2024-08-14 13:41:04 -07:00
parent 4de5bc35a2
commit 3026e69926
4 changed files with 55 additions and 8 deletions

View file

@ -4898,7 +4898,6 @@ async def ahealth_check(
verbose_logger.error(
"litellm.ahealth_check(): Exception occured - {}".format(str(e))
)
verbose_logger.debug(traceback.format_exc())
stack_trace = traceback.format_exc()
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]
@ -4907,7 +4906,12 @@ async def ahealth_check(
"error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
}
error_to_return = str(e) + " stack trace: " + stack_trace
error_to_return = (
str(e)
+ "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models"
+ "\nstack trace: "
+ stack_trace
)
return {"error": error_to_return}

View file

@ -1,7 +1,7 @@
model_list:
- model_name: "*"
- model_name: azure-embedding-model
litellm_params:
model: "*"
litellm_settings:
success_callback: ["langsmith"]
model: azure/azure-embedding-model
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"

View file

@ -190,6 +190,31 @@ def test_completion_azure_command_r():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize(
    "api_base",
    [
        "https://litellm8397336933.openai.azure.com",
        "https://litellm8397336933.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2023-03-15-preview",
    ],
)
def test_completion_azure_ai_gpt_4o(api_base):
    """Smoke-test an OpenAI model routed through the azure_ai/ provider.

    Runs once with a bare endpoint and once with a fully-qualified
    deployment URL. A timeout is tolerated; any other exception fails
    the test.
    """
    try:
        litellm.set_verbose = True
        result = completion(
            model="azure_ai/gpt-4o",
            api_base=api_base,
            api_key=os.getenv("AZURE_AI_OPENAI_KEY"),
            messages=[{"role": "user", "content": "What is the meaning of life?"}],
        )
        print(result)
    except litellm.Timeout as e:
        # Network timeouts are an acceptable outcome for this smoke test.
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_databricks(sync_mode):

View file

@ -4479,7 +4479,22 @@ def _is_non_openai_azure_model(model: str) -> bool:
or f"mistral/{model_name}" in litellm.mistral_chat_models
):
return True
except:
except Exception:
return False
return False
def _is_azure_openai_model(model: str) -> bool:
    """Return True if `model` names an OpenAI model (so an `azure_ai/` route
    should be handled by the `azure` provider instead).

    Accepts an optional single `provider/` prefix, e.g.
    "azure_ai/gpt-4o" -> "gpt-4o", then checks litellm's OpenAI model
    registries. Any lookup failure is treated as "not an OpenAI model".
    """
    try:
        if "/" in model:
            # Strip one leading provider prefix; maxsplit=1 keeps model
            # names that themselves contain "/" intact.
            model = model.split("/", 1)[1]
        return (
            model in litellm.open_ai_chat_completion_models
            or model in litellm.open_ai_text_completion_models
            # BUG FIX: the original tested the truthiness of the
            # embedding-model list itself (`or litellm.open_ai_embedding_models`),
            # which made this return True for EVERY model whenever that
            # list was non-empty. Membership is what was intended.
            or model in litellm.open_ai_embedding_models
        )
    except Exception:
        return False
@ -4613,6 +4628,9 @@ def get_llm_provider(
elif custom_llm_provider == "azure_ai":
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
if _is_azure_openai_model(model=model):
custom_llm_provider = "azure"
elif custom_llm_provider == "github":
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")