litellm-mirror/tests/local_testing/test_get_llm_provider.py
Krish Dholakia 0120176541
Litellm dev 12 30 2024 p2 (#7495)
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

skip as o1 on azure doesn't support tool calling yet

* fix: initial commit of azure o1 handler using openai caller

simplifies calling + allows fake streaming logic already implemented for openai to just work

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

enables user to toggle on when azure allows o1 streaming without needing to bump versions

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes https://github.com/BerriAI/litellm/issues/5942

* fix(types/utils.py): handle none logprobs

Fixes https://github.com/BerriAI/litellm/issues/328

* fix(exception_mapping_utils.py): fix error str unbound error

* refactor(azure_ai/): move to openai_like chat completion handler

allows for easy swapping of API base URLs (e.g. ai.services.com)

Fixes https://github.com/BerriAI/litellm/issues/7275

* refactor(azure_ai/): move to base llm http handler

* fix(azure_ai/): handle differing api endpoints

* fix(azure_ai/): make sure all unit tests are passing

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting error

* fix: fix linting errors

* fix(azure_ai/transformation.py): handle extra body param

* fix(azure_ai/transformation.py): fix max retries param handling

* fix: fix test

* test(test_azure_o1.py): fix test

* fix(llm_http_handler.py): support handling azure ai unprocessable entity error

* fix(llm_http_handler.py): handle sync invalid param error for azure ai

* fix(azure_ai/): streaming support with base_llm_http_handler

* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai

* fix: fix linting errors

* fix(llm_http_handler.py): fix linting error

* fix(azure_ai/): handle cohere tool call invalid index param error
2025-01-01 18:57:29 -08:00

202 lines
6.5 KiB
Python

import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
from unittest.mock import patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
def test_get_llm_provider():
    """Bedrock-style claude model ids resolve to the 'bedrock' provider."""
    _, provider, _, _ = litellm.get_llm_provider(model="anthropic.claude-v2:1")
    assert provider == "bedrock"
def test_get_llm_provider_fireworks():
    """Finetuned fireworks models keep their full account path as the model name.

    Regression test for https://github.com/BerriAI/litellm/issues/4923
    """
    resolved_model, provider, _, _ = litellm.get_llm_provider(
        model="fireworks_ai/accounts/my-test-1234"
    )
    assert provider == "fireworks_ai"
    assert resolved_model == "accounts/my-test-1234"
def test_get_llm_provider_catch_all():
    """The wildcard model '*' falls back to the 'openai' provider."""
    _, provider, _, _ = litellm.get_llm_provider(model="*")
    assert provider == "openai"
def test_get_llm_provider_gpt_instruct():
    """gpt-*-instruct snapshot models route to the text-completion-openai provider."""
    _, provider, _, _ = litellm.get_llm_provider(model="gpt-3.5-turbo-instruct-0914")
    assert provider == "text-completion-openai"
def test_get_llm_provider_mistral_custom_api_base():
    """A caller-supplied api_base for a mistral/ model is passed through untouched."""
    custom_base = (
        "https://mistral-large-fr-ishaan.francecentral.inference.ai.azure.com/v1"
    )
    resolved_model, provider, _, api_base = litellm.get_llm_provider(
        model="mistral/mistral-large-fr",
        api_base=custom_base,
    )
    assert provider == "mistral"
    assert resolved_model == "mistral-large-fr"
    assert api_base == custom_base
def test_get_llm_provider_deepseek_custom_api_base():
    """Deepseek provider resolution honors DEEPSEEK_API_BASE from the environment.

    Uses patch.dict so the env var is restored even when an assertion fails —
    the original set/pop pattern skipped the pop() on failure, leaking the
    fake base into every later test in the session.
    """
    with patch.dict(os.environ, {"DEEPSEEK_API_BASE": "MY-FAKE-BASE"}):
        model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
            model="deepseek/deep-chat",
        )
        assert custom_llm_provider == "deepseek"
        assert model == "deep-chat"
        assert api_base == "MY-FAKE-BASE"
def test_get_llm_provider_vertex_ai_image_models():
    """Vertex AI image-generation ids resolve even with no explicit provider."""
    _, provider, _, _ = litellm.get_llm_provider(
        model="imagegeneration@006", custom_llm_provider=None
    )
    assert provider == "vertex_ai"
def test_get_llm_provider_ai21_chat():
    """Bare jamba model names map to ai21_chat with the AI21 studio base URL."""
    resolved_model, provider, _, api_base = litellm.get_llm_provider(
        model="jamba-1.5-large",
    )
    assert provider == "ai21_chat"
    assert resolved_model == "jamba-1.5-large"
    assert api_base == "https://api.ai21.com/studio/v1"
def test_get_llm_provider_ai21_chat_test2():
    """An explicit 'ai21/' prefix on a jamba model still resolves to ai21_chat."""
    resolved_model, provider, _, api_base = litellm.get_llm_provider(
        model="ai21/jamba-1.5-large",
    )
    print("model=", resolved_model)
    print("custom_llm_provider=", provider)
    print("api_base=", api_base)
    assert provider == "ai21_chat"
    assert resolved_model == "jamba-1.5-large"
    assert api_base == "https://api.ai21.com/studio/v1"
def test_get_llm_provider_cohere_chat_test2():
    """An explicit 'cohere/' prefix on command-r-plus still resolves to cohere_chat."""
    resolved_model, provider, _, api_base = litellm.get_llm_provider(
        model="cohere/command-r-plus",
    )
    print("model=", resolved_model)
    print("custom_llm_provider=", provider)
    print("api_base=", api_base)
    assert provider == "cohere_chat"
    assert resolved_model == "command-r-plus"
def test_get_llm_provider_azure_o1():
    """azure/-prefixed o1 models resolve to the plain 'azure' provider."""
    resolved_model, provider, _, _ = litellm.get_llm_provider(
        model="azure/o1-mini",
    )
    assert provider == "azure"
    assert resolved_model == "o1-mini"
def test_default_api_base():
    """Sanity-check default api_base values for every openai-compatible provider.

    With the environment cleared, a provider's default api_base must not
    contain any *other* provider's name (after stripping the generic
    '/openai' path segment) — that would indicate a copy/paste mixup.
    """
    from litellm.litellm_core_utils.get_llm_provider_logic import (
        _get_openai_compatible_provider_info,
    )

    # Clear the environment so provider-specific *_API_BASE overrides don't apply
    with patch.dict(os.environ, {}, clear=True):
        for provider in litellm.openai_compatible_providers:
            _, _, _, api_base = _get_openai_compatible_provider_info(
                model=f"{provider}/*", api_base=None, api_key=None, dynamic_api_key=None
            )
            if api_base is None:
                continue
            sanitized_base = api_base.replace("/openai", "")
            for other_provider in litellm.provider_list:
                # A provider legitimately shares its own name / its *_chat variant
                if other_provider == provider or provider == "{}_chat".format(
                    other_provider.value
                ):
                    continue
                # Known intentional overlaps between provider hostnames
                if provider == "codestral" and other_provider == "mistral":
                    continue
                if provider == "github" and other_provider == "azure":
                    continue
                assert other_provider.value not in sanitized_base
def test_hosted_vllm_default_api_key():
    """hosted_vllm supplies a placeholder api key when none is configured."""
    from litellm.litellm_core_utils.get_llm_provider_logic import (
        _get_openai_compatible_provider_info,
    )

    _, _, resolved_key, _ = _get_openai_compatible_provider_info(
        model="hosted_vllm/llama-3.1-70b-instruct",
        api_base=None,
        api_key=None,
        dynamic_api_key=None,
    )
    assert resolved_key == "fake-api-key"
def test_get_llm_provider_jina_ai():
    """jina_ai/ models resolve to the jina_ai provider and its default base URL."""
    resolved_model, provider, _, api_base = litellm.get_llm_provider(
        model="jina_ai/jina-embeddings-v3",
    )
    assert provider == "jina_ai"
    assert api_base == "https://api.jina.ai/v1"
    assert resolved_model == "jina-embeddings-v3"
def test_get_llm_provider_hosted_vllm():
    """hosted_vllm/ models resolve with the provider's placeholder api key."""
    resolved_model, provider, resolved_key, _ = litellm.get_llm_provider(
        model="hosted_vllm/llama-3.1-70b-instruct",
    )
    assert provider == "hosted_vllm"
    assert resolved_model == "llama-3.1-70b-instruct"
    assert resolved_key == "fake-api-key"
def test_get_llm_provider_watson_text():
    """watsonx_text/ models resolve to the watsonx_text provider."""
    resolved_model, provider, _, _ = litellm.get_llm_provider(
        model="watsonx_text/watson-text-to-speech",
    )
    assert provider == "watsonx_text"
    assert resolved_model == "watson-text-to-speech"
def test_azure_global_standard_get_llm_provider():
    """azure_ai/ models with a services.ai.azure.com base resolve to azure_ai."""
    _, provider, _, _ = litellm.get_llm_provider(
        model="azure_ai/gpt-4o-global-standard",
        api_base="https://my-deployment-francecentral.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
        api_key="fake-api-key",
    )
    assert provider == "azure_ai"