test(router_code_coverage.py): check if all router functions are directly tested (#6186)

* test(router_code_coverage.py): check if all router functions are directly tested

prevent regressions

* docs(configs.md): document all environment variables (#6185)

* docs: make it easier to find anthropic/openai prompt caching doc

* added codecov yml (#6207)

* fix codecov.yaml

* run ci/cd again

* (refactor) caching use LLMCachingHandler for async_get_cache and set_cache  (#6208)

* use folder for caching

* fix importing caching

* fix clickhouse pyright

* fix linting

* fix correctly pass kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix refactor handle utils.py

* fix test_embedding_caching_azure_individual_items_reordered

* (feat) prometheus: well-defined latency buckets (#6211)

* fix prometheus to have well-defined latency buckets

* use a well-defined latency bucket

* use types file for prometheus logging

* add test for LATENCY_BUCKETS

* fix prom testing

* fix config.yml

* (refactor caching) use LLMCachingHandler for caching streaming responses  (#6210)

* use folder for caching

* fix importing caching

* fix clickhouse pyright

* fix linting

* fix correctly pass kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix refactor handle utils.py

* refactor async set stream cache

* fix linting

* bump (#6187)

* update code cov yaml

* fix config.yml

* add caching component to code cov

* fix config.yml ci/cd

* add coverage for proxy auth

* (refactor caching) use common `_retrieve_from_cache` helper  (#6212)

* use folder for caching

* fix importing caching

* fix clickhouse pyright

* fix linting

* fix correctly pass kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix refactor handle utils.py

* refactor async set stream cache

* fix linting

* refactor - use _retrieve_from_cache

* refactor use _convert_cached_result_to_model_response

* fix linting errors

* bump: version 1.49.2 → 1.49.3

* fix code cov components

* test(test_router_helpers.py): add router component unit tests

* test: add additional router tests

* test: add more router testing

* test: add more router testing + more mock functions

* ci(router_code_coverage.py): fix check

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: yujonglee <yujonglee.dev@gmail.com>
Krish Dholakia 2024-10-14 22:44:00 -07:00 committed by GitHub
parent 39486e2003
commit 1eb435e50a
15 changed files with 768 additions and 164 deletions


@@ -199,7 +199,7 @@ jobs:
command: |
pwd
ls
python -m pytest tests/local_testing --cov=litellm --cov-report=xml -vv -k "router" -x -s -v --junitxml=test-results/junit.xml --durations=5
python -m pytest tests/local_testing tests/router_unit_tests --cov=litellm --cov-report=xml -vv -k "router" -x -s -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
@@ -380,6 +380,7 @@ jobs:
- run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
- run: ruff check ./litellm
- run: python ./tests/documentation_tests/test_general_setting_keys.py
- run: python ./tests/code_coverage_tests/router_code_coverage.py
- run: python ./tests/documentation_tests/test_env_keys.py
db_migration_disable_update_check:


@@ -0,0 +1,28 @@
from typing import List, Optional
from ..types.utils import (
Categories,
CategoryAppliedInputTypes,
CategoryScores,
Embedding,
EmbeddingResponse,
ImageObject,
ImageResponse,
Moderation,
ModerationCreateResponse,
)
def mock_embedding(model: str, mock_response: Optional[List[float]]):
if mock_response is None:
mock_response = [0.0] * 1536
return EmbeddingResponse(
model=model,
data=[Embedding(embedding=mock_response, index=0, object="embedding")],
)
def mock_image_generation(model: str, mock_response: str):
return ImageResponse(
data=[ImageObject(url=mock_response)],
)
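
A minimal usage sketch for these new helpers (the model names and vector values below are illustrative, not taken from this PR):

from litellm.litellm_core_utils.mock_functions import (
    mock_embedding,
    mock_image_generation,
)

# mock_embedding wraps the supplied vector in an EmbeddingResponse at index 0;
# passing mock_response=None falls back to a 1536-dimensional zero vector.
emb = mock_embedding(model="text-embedding-ada-002", mock_response=[0.1, 0.2, 0.3])
print(emb.data[0].embedding)  # [0.1, 0.2, 0.3]

# mock_image_generation returns an ImageResponse whose first ImageObject carries the URL.
img = mock_image_generation(model="dall-e-3", mock_response="https://example.com/image.png")
print(img.data[0].url)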


@@ -42,6 +42,10 @@ from litellm import ( # type: ignore
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.mock_functions import (
mock_embedding,
mock_image_generation,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
@@ -3163,6 +3167,7 @@ def embedding(
tpm = kwargs.pop("tpm", None)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
@@ -3268,6 +3273,9 @@
custom_llm_provider=custom_llm_provider,
**non_default_params,
)
if mock_response is not None:
return mock_embedding(model=model, mock_response=mock_response)
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model(
@@ -4377,6 +4385,7 @@ def image_generation(
aimg_generation = kwargs.get("aimg_generation", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
@@ -4486,6 +4495,8 @@
},
custom_llm_provider=custom_llm_provider,
)
if mock_response is not None:
return mock_image_generation(model=model, mock_response=mock_response)
if custom_llm_provider == "azure":
# azure configs
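
For reference, a sketch of how the new mock_response kwarg wired in above can be used from the user-facing functions (model names and values are illustrative; with mock_response set, no provider request should be made):

import litellm

# Embedding: a list of floats short-circuits to mock_embedding() right after
# provider resolution, before any network call.
emb = litellm.embedding(
    model="text-embedding-ada-002",
    input=["good morning from litellm"],
    mock_response=[0.1, 0.2, 0.3],
)

# Image generation: a URL string short-circuits to mock_image_generation().
img = litellm.image_generation(
    model="dall-e-3",
    prompt="A cute baby sea otter",
    mock_response="https://example.com/image.png",
)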


@@ -111,6 +111,7 @@ from litellm.types.router import (
RouterModelGroupAliasItem,
RouterRateLimitError,
RouterRateLimitErrorBasic,
RoutingStrategy,
updateDeployment,
updateLiteLLMParams,
)
@@ -519,6 +520,9 @@ class Router:
self._initialize_alerting()
def validate_fallbacks(self, fallback_param: Optional[List]):
"""
Validate the fallbacks parameter.
"""
if fallback_param is None:
return
@@ -530,8 +534,13 @@
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
)
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy":
def routing_strategy_init(
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
):
if (
routing_strategy == RoutingStrategy.LEAST_BUSY.value
or routing_strategy == RoutingStrategy.LEAST_BUSY
):
self.leastbusy_logger = LeastBusyLoggingHandler(
router_cache=self.cache, model_list=self.model_list
)
@@ -542,7 +551,10 @@
litellm.input_callback = [self.leastbusy_logger] # type: ignore
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
elif routing_strategy == "usage-based-routing":
elif (
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING
):
self.lowesttpm_logger = LowestTPMLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
@@ -550,7 +562,10 @@
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
elif routing_strategy == "usage-based-routing-v2":
elif (
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2
):
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
router_cache=self.cache,
model_list=self.model_list,
@@ -558,7 +573,10 @@
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
elif routing_strategy == "latency-based-routing":
elif (
routing_strategy == RoutingStrategy.LATENCY_BASED.value
or routing_strategy == RoutingStrategy.LATENCY_BASED
):
self.lowestlatency_logger = LowestLatencyLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
@@ -566,7 +584,10 @@
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
elif routing_strategy == "cost-based-routing":
elif (
routing_strategy == RoutingStrategy.COST_BASED.value
or routing_strategy == RoutingStrategy.COST_BASED
):
self.lowestcost_logger = LowestCostLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
@@ -574,10 +595,14 @@
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestcost_logger) # type: ignore
else:
pass
def print_deployment(self, deployment: dict):
"""
returns a copy of the deployment with the api key masked
Only returns 2 characters of the api key and masks the rest with * (10 *).
"""
try:
_deployment_copy = copy.deepcopy(deployment)
@@ -1746,7 +1771,6 @@ class Router:
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self.text_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
@@ -1770,13 +1794,7 @@
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self.text_completion
return self.function_with_retries(**kwargs)
else:
raise e
raise e
async def atext_completion(
self,
@@ -3005,7 +3023,7 @@ class Router:
async def make_call(self, original_function: Any, *args, **kwargs):
"""
Handler for making a call to the .completion()/.embeddings() functions.
Handler for making a call to the .completion()/.embeddings()/etc. functions.
"""
model_group = kwargs.get("model")
response = await original_function(*args, **kwargs)


@@ -608,3 +608,11 @@ VALID_LITELLM_ENVIRONMENTS = [
"staging",
"production",
]
class RoutingStrategy(enum.Enum):
LEAST_BUSY = "least-busy"
LATENCY_BASED = "latency-based-routing"
COST_BASED = "cost-based-routing"
USAGE_BASED_ROUTING_V2 = "usage-based-routing-v2"
USAGE_BASED_ROUTING = "usage-based-routing"
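
A brief sketch of the updated strategy initialization, which now accepts either a RoutingStrategy member or its string value (mirrors the new test_routing_strategy_init unit test; the deployment below is illustrative):

from litellm import Router
from litellm.types.router import RoutingStrategy

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "mock_response": "hi"},
        }
    ]
)

# Both spellings select the latency-based branch in routing_strategy_init().
router.routing_strategy_init(
    routing_strategy=RoutingStrategy.LATENCY_BASED, routing_strategy_args={}
)
router.routing_strategy_init(
    routing_strategy="latency-based-routing", routing_strategy_args={}
)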


@@ -11,6 +11,12 @@ from openai.types.completion_usage import (
CompletionUsage,
PromptTokensDetails,
)
from openai.types.moderation import (
Categories,
CategoryAppliedInputTypes,
CategoryScores,
)
from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
from pydantic import BaseModel, ConfigDict, PrivateAttr
from typing_extensions import Callable, Dict, Required, TypedDict, override
@@ -20,6 +26,7 @@ from .llms.openai import (
ChatCompletionUsageBlock,
OpenAIChatCompletionChunk,
)
from .rerank import RerankResponse
def _generate_id(): # private helper function
@@ -811,7 +818,7 @@ class EmbeddingResponse(OpenAIObject):
model: Optional[str] = None,
usage: Optional[Usage] = None,
response_ms=None,
data: Optional[List] = None,
data: Optional[Union[List, List[Embedding]]] = None,
hidden_params=None,
_response_headers=None,
**params,


@@ -0,0 +1,117 @@
import ast
import os
def get_function_names_from_file(file_path):
"""
Extracts all function names from a given Python file.
"""
with open(file_path, "r") as file:
tree = ast.parse(file.read())
function_names = []
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
function_names.append(node.name)
return function_names
def get_all_functions_called_in_tests(base_dir):
"""
Returns a set of function names that are called in test functions
inside 'local_testing' and 'router_unit_tests' directories,
specifically in files containing the word 'router'.
"""
called_functions = set()
test_dirs = ["local_testing", "router_unit_tests"]
for test_dir in test_dirs:
dir_path = os.path.join(base_dir, test_dir)
if not os.path.exists(dir_path):
print(f"Warning: Directory {dir_path} does not exist.")
continue
print("dir_path: ", dir_path)
for root, _, files in os.walk(dir_path):
for file in files:
if file.endswith(".py") and "router" in file.lower():
print("file: ", file)
file_path = os.path.join(root, file)
with open(file_path, "r") as f:
try:
tree = ast.parse(f.read())
except SyntaxError:
print(f"Warning: Syntax error in file {file_path}")
continue
if file == "test_router_validate_fallbacks.py":
print(f"tree: {tree}")
for node in ast.walk(tree):
if isinstance(node, ast.Call) and isinstance(
node.func, ast.Name
):
called_functions.add(node.func.id)
elif isinstance(node, ast.Call) and isinstance(
node.func, ast.Attribute
):
called_functions.add(node.func.attr)
return called_functions
def get_functions_from_router(file_path):
"""
Extracts all functions defined in router.py.
"""
return get_function_names_from_file(file_path)
ignored_function_names = [
"__init__",
"_acreate_file",
"_acreate_batch",
"acreate_assistants",
"adelete_assistant",
"aget_assistants",
"acreate_thread",
"aget_thread",
"a_add_message",
"aget_messages",
"arun_thread",
]
def main():
router_file = "./litellm/router.py" # Update this path if it's located elsewhere
# router_file = "../../litellm/router.py" ## LOCAL TESTING
tests_dir = (
"./tests/" # Update this path if your tests directory is located elsewhere
)
# tests_dir = "../../tests/" # LOCAL TESTING
router_functions = get_functions_from_router(router_file)
print("router_functions: ", router_functions)
called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
untested_functions = [
fn for fn in router_functions if fn not in called_functions_in_tests
]
if untested_functions:
all_untested_functions = []
for func in untested_functions:
if func not in ignored_function_names:
all_untested_functions.append(func)
untested_perc = (len(all_untested_functions)) / len(router_functions)
print("perc_covered: ", untested_perc)
if untested_perc < 0.3:
print("The following functions in router.py are not tested:")
raise Exception(
f"{untested_perc * 100:.2f}% of functions in router.py are not tested: {all_untested_functions}"
)
else:
print("All functions in router.py are covered by tests.")
if __name__ == "__main__":
main()
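
A self-contained sketch of the AST cross-referencing idea the script relies on — collect defined function names, collect called names, and diff them (the sample source below is invented for illustration):

import ast

sample_source = '''
def routed_call():
    return helper()

def untested_helper():
    pass
'''

tree = ast.parse(sample_source)
defined = {
    node.name
    for node in ast.walk(tree)
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
}
called = {
    node.func.id if isinstance(node.func, ast.Name) else node.func.attr
    for node in ast.walk(tree)
    if isinstance(node, ast.Call) and isinstance(node.func, (ast.Name, ast.Attribute))
}
print(defined - called)  # {'routed_call', 'untested_helper'} -- neither is called in this sample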


@@ -366,41 +366,6 @@ def test_anthropic_tool_streaming():
assert tool_use["index"] == correct_tool_index
@pytest.mark.asyncio
async def test_anthropic_router_completion_e2e():
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
router = Router(
model_list=[
{
"model_name": "claude-3-5-sonnet-20240620",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "hi this is macintosh.",
},
}
]
)
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = await router.aadapter_completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
print("Response: {}".format(response))
assert response is not None
assert isinstance(response, AnthropicResponse)
assert response.model == "gpt-3.5-turbo"
def test_anthropic_tool_calling_translation():
kwargs = {
"model": "claude-3-5-sonnet-20240620",


@@ -83,43 +83,6 @@ async def test_audio_speech_litellm(sync_mode, model, api_base, api_key):
assert isinstance(response, HttpxBinaryResponseContent)
@pytest.mark.parametrize("mode", ["iterator"]) # "file",
@pytest.mark.asyncio
async def test_audio_speech_router(mode):
speech_file_path = Path(__file__).parent / "speech.mp3"
from litellm import Router
client = Router(
model_list=[
{
"model_name": "tts",
"litellm_params": {
"model": "openai/tts-1",
},
},
]
)
response = await client.aspeech(
model="tts",
voice="alloy",
input="the quick brown fox jumped over the lazy dogs",
api_base=None,
api_key=None,
organization=None,
project=None,
max_retries=1,
timeout=600,
client=None,
optional_params={},
)
from litellm.llms.OpenAI.openai import HttpxBinaryResponseContent
assert isinstance(response, HttpxBinaryResponseContent)
@pytest.mark.parametrize(
"sync_mode",
[False, True],


@@ -1310,19 +1310,38 @@ def test_aembedding_on_router():
router = Router(model_list=model_list)
async def embedding_call():
## Test 1: user facing function
response = await router.aembedding(
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
)
print(response)
## Test 2: underlying function
response = await router._aembedding(
model="text-embedding-ada-002",
input=["good morning from litellm 2"],
)
print(response)
router.reset()
asyncio.run(embedding_call())
print("\n Making sync Embedding call\n")
## Test 1: user facing function
response = router.embedding(
model="text-embedding-ada-002",
input=["good morning from litellm 2"],
)
print(response)
router.reset()
## Test 2: underlying function
response = router._embedding(
model="text-embedding-ada-002",
input=["good morning from litellm 2"],
)
print(response)
router.reset()
except Exception as e:
if "Your task failed as a result of our safety system." in str(e):
@@ -1843,10 +1862,16 @@ async def test_router_amoderation():
]
router = Router(model_list=model_list)
## Test 1: user facing function
result = await router.amoderation(
model="openai-moderations", input="this is valid good text"
)
## Test 2: underlying function
result = await router._amoderation(
model="openai-moderations", input="this is valid good text"
)
print("moderation result", result)


@@ -79,81 +79,6 @@ async def test_transcription(model, api_key, api_base, response_format, sync_mode)
assert transcript.text is not None
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger):
def __init__(self):
self.openai_client = None
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
# init logging config
print("logging a transcript kwargs: ", kwargs)
print("openai client=", kwargs.get("client"))
self.openai_client = kwargs.get("client")
except Exception:
pass
proxy_handler_instance = MyCustomHandler()
# Set litellm.callbacks = [proxy_handler_instance] on the proxy
# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
@pytest.mark.asyncio
async def test_transcription_on_router():
litellm.set_verbose = True
litellm.callbacks = [proxy_handler_instance]
print("\n Testing async transcription on router\n")
try:
model_list = [
{
"model_name": "whisper",
"litellm_params": {
"model": "whisper-1",
},
},
{
"model_name": "whisper",
"litellm_params": {
"model": "azure/azure-whisper",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
"api_key": os.getenv("AZURE_EUROPE_API_KEY"),
"api_version": "2024-02-15-preview",
},
},
]
router = Router(model_list=model_list)
router_level_clients = []
for deployment in router.model_list:
_deployment_openai_client = router._get_client(
deployment=deployment,
kwargs={"model": "whisper-1"},
client_type="async",
)
router_level_clients.append(str(_deployment_openai_client))
response = await router.atranscription(
model="whisper",
file=audio_file,
)
print(response)
# PROD Test
# Ensure we ONLY use OpenAI/Azure client initialized on the router level
await asyncio.sleep(5)
print("OpenAI Client used= ", proxy_handler_instance.openai_client)
print("all router level clients= ", router_level_clients)
assert proxy_handler_instance.openai_client in router_level_clients
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio()
async def test_transcription_caching():
import litellm


@@ -0,0 +1,5 @@
## Router component unit tests.
Please name all files with the word 'router' in them.
This is used to ensure all functions in the router are tested.
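
As an illustration of that convention, a hypothetical unit test file whose name contains 'router' (the file name, model, and mock value are invented):

# tests/router_unit_tests/test_router_mock_completion.py
from litellm import Router


def test_router_mock_completion():
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "mock_response": "hi"},
            }
        ]
    )
    # mock_response in litellm_params keeps the test offline.
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    assert response.choices[0].message.content == "hi"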

Binary file not shown.


@@ -0,0 +1,279 @@
import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm import Router, CustomLogger
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)
file_path = os.path.join(pwd, "gettysburg.wav")
audio_file = open(file_path, "rb")
from pathlib import Path
import litellm
import pytest
import asyncio
@pytest.fixture
def model_list():
return [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "gpt-4o",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "cohere-rerank",
"litellm_params": {
"model": "cohere/rerank-english-v3.0",
"api_key": os.getenv("COHERE_API_KEY"),
},
},
{
"model_name": "claude-3-5-sonnet-20240620",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "hi this is macintosh.",
},
},
]
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger):
def __init__(self):
self.openai_client = None
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
# init logging config
print("logging a transcript kwargs: ", kwargs)
print("openai client=", kwargs.get("client"))
self.openai_client = kwargs.get("client")
except Exception:
pass
proxy_handler_instance = MyCustomHandler()
# Set litellm.callbacks = [proxy_handler_instance] on the proxy
# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
@pytest.mark.asyncio
async def test_transcription_on_router():
litellm.set_verbose = True
litellm.callbacks = [proxy_handler_instance]
print("\n Testing async transcription on router\n")
try:
model_list = [
{
"model_name": "whisper",
"litellm_params": {
"model": "whisper-1",
},
},
{
"model_name": "whisper",
"litellm_params": {
"model": "azure/azure-whisper",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
"api_key": os.getenv("AZURE_EUROPE_API_KEY"),
"api_version": "2024-02-15-preview",
},
},
]
router = Router(model_list=model_list)
router_level_clients = []
for deployment in router.model_list:
_deployment_openai_client = router._get_client(
deployment=deployment,
kwargs={"model": "whisper-1"},
client_type="async",
)
router_level_clients.append(str(_deployment_openai_client))
## test 1: user facing function
response = await router.atranscription(
model="whisper",
file=audio_file,
)
## test 2: underlying function
response = await router._atranscription(
model="whisper",
file=audio_file,
)
print(response)
# PROD Test
# Ensure we ONLY use OpenAI/Azure client initialized on the router level
await asyncio.sleep(5)
print("OpenAI Client used= ", proxy_handler_instance.openai_client)
print("all router level clients= ", router_level_clients)
assert proxy_handler_instance.openai_client in router_level_clients
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("mode", ["iterator"]) # "file",
@pytest.mark.asyncio
async def test_audio_speech_router(mode):
from litellm import Router
client = Router(
model_list=[
{
"model_name": "tts",
"litellm_params": {
"model": "openai/tts-1",
},
},
]
)
response = await client.aspeech(
model="tts",
voice="alloy",
input="the quick brown fox jumped over the lazy dogs",
api_base=None,
api_key=None,
organization=None,
project=None,
max_retries=1,
timeout=600,
client=None,
optional_params={},
)
from litellm.llms.OpenAI.openai import HttpxBinaryResponseContent
assert isinstance(response, HttpxBinaryResponseContent)
@pytest.mark.asyncio()
async def test_rerank_endpoint(model_list):
from litellm.types.utils import RerankResponse
router = Router(model_list=model_list)
## Test 1: user facing function
response = await router.arerank(
model="cohere-rerank",
query="hello",
documents=["hello", "world"],
top_n=3,
)
## Test 2: underlying function
response = await router._arerank(
model="cohere-rerank",
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("async re rank response: ", response)
assert response.id is not None
assert response.results is not None
RerankResponse.model_validate(response)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_text_completion_endpoint(model_list, sync_mode):
router = Router(model_list=model_list)
if sync_mode:
response = router.text_completion(
model="gpt-3.5-turbo",
prompt="Hello, how are you?",
mock_response="I'm fine, thank you!",
)
else:
## Test 1: user facing function
response = await router.atext_completion(
model="gpt-3.5-turbo",
prompt="Hello, how are you?",
mock_response="I'm fine, thank you!",
)
## Test 2: underlying function
response_2 = await router._atext_completion(
model="gpt-3.5-turbo",
prompt="Hello, how are you?",
mock_response="I'm fine, thank you!",
)
assert response_2.choices[0].text == "I'm fine, thank you!"
assert response.choices[0].text == "I'm fine, thank you!"
@pytest.mark.asyncio
async def test_anthropic_router_completion_e2e(model_list):
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
router = Router(model_list=model_list)
messages = [{"role": "user", "content": "Hey, how's it going?"}]
## Test 1: user facing function
response = await router.aadapter_completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
## Test 2: underlying function
await router._aadapter_completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
print("Response: {}".format(response))
assert response is not None
AnthropicResponse.model_validate(response)
assert response.model == "gpt-3.5-turbo"


@@ -0,0 +1,252 @@
import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm import Router
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
@pytest.fixture
def model_list():
return [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "gpt-4o",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
def test_validate_fallbacks(model_list):
router = Router(model_list=model_list, fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}])
router.validate_fallbacks(fallback_param=[{"gpt-4o": "gpt-3.5-turbo"}])
def test_routing_strategy_init(model_list):
"""Test if all routing strategies are initialized correctly"""
from litellm.types.router import RoutingStrategy
router = Router(model_list=model_list)
for strategy in RoutingStrategy._member_names_:
router.routing_strategy_init(
routing_strategy=strategy, routing_strategy_args={}
)
def test_print_deployment(model_list):
"""Test if the api key is masked correctly"""
router = Router(model_list=model_list)
deployment = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}
printed_deployment = router.print_deployment(deployment)
assert 10 * "*" in printed_deployment["litellm_params"]["api_key"]
def test_completion(model_list):
"""Test if the completion function is working correctly"""
router = Router(model_list=model_list)
response = router._completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="I'm fine, thank you!",
)
assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.flaky(retries=6, delay=1)
@pytest.mark.asyncio
async def test_image_generation(model_list, sync_mode):
"""Test if the underlying '_image_generation' function is working correctly"""
from litellm.types.utils import ImageResponse
router = Router(model_list=model_list)
if sync_mode:
response = router._image_generation(
model="dall-e-3",
prompt="A cute baby sea otter",
)
else:
response = await router._aimage_generation(
model="dall-e-3",
prompt="A cute baby sea otter",
)
ImageResponse.model_validate(response)
@pytest.mark.asyncio
async def test_router_acompletion_util(model_list):
"""Test if the underlying '_acompletion' function is working correctly"""
router = Router(model_list=model_list)
response = await router._acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="I'm fine, thank you!",
)
assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
@pytest.mark.asyncio
async def test_router_abatch_completion_one_model_multiple_requests_util(model_list):
"""Test if the 'abatch_completion_one_model_multiple_requests' function is working correctly"""
router = Router(model_list=model_list)
response = await router.abatch_completion_one_model_multiple_requests(
model="gpt-3.5-turbo",
messages=[
[{"role": "user", "content": "Hello, how are you?"}],
[{"role": "user", "content": "Hello, how are you?"}],
],
mock_response="I'm fine, thank you!",
)
print(response)
assert response[0]["choices"][0]["message"]["content"] == "I'm fine, thank you!"
assert response[1]["choices"][0]["message"]["content"] == "I'm fine, thank you!"
@pytest.mark.asyncio
async def test_router_schedule_acompletion(model_list):
"""Test if the 'schedule_acompletion' function is working correctly"""
router = Router(model_list=model_list)
response = await router.schedule_acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="I'm fine, thank you!",
priority=1,
)
assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
@pytest.mark.asyncio
async def test_router_arealtime(model_list):
"""Test if the '_arealtime' function is working correctly"""
import litellm
router = Router(model_list=model_list)
with patch.object(litellm, "_arealtime", AsyncMock()) as mock_arealtime:
mock_arealtime.return_value = "I'm fine, thank you!"
await router._arealtime(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
mock_arealtime.assert_awaited_once()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_router_function_with_fallbacks(model_list, sync_mode):
"""Test if the router 'async_function_with_fallbacks' + 'function_with_fallbacks' are working correctly"""
router = Router(model_list=model_list)
data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"mock_response": "I'm fine, thank you!",
"num_retries": 0,
}
if sync_mode:
response = router.function_with_fallbacks(
original_function=router._completion,
**data,
)
else:
response = await router.async_function_with_fallbacks(
original_function=router._acompletion,
**data,
)
assert response.choices[0].message.content == "I'm fine, thank you!"
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_router_function_with_retries(model_list, sync_mode):
"""Test if the router 'async_function_with_retries' + 'function_with_retries' are working correctly"""
router = Router(model_list=model_list)
data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"mock_response": "I'm fine, thank you!",
"num_retries": 0,
}
if sync_mode:
response = router.function_with_retries(
original_function=router._completion,
**data,
)
else:
response = await router.async_function_with_retries(
original_function=router._acompletion,
**data,
)
assert response.choices[0].message.content == "I'm fine, thank you!"
@pytest.mark.asyncio
async def test_router_make_call(model_list):
"""Test if the router 'make_call' function is working correctly"""
## ACOMPLETION
router = Router(model_list=model_list)
response = await router.make_call(
original_function=router._acompletion,
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="I'm fine, thank you!",
)
assert response.choices[0].message.content == "I'm fine, thank you!"
## ATEXT_COMPLETION
response = await router.make_call(
original_function=router._atext_completion,
model="gpt-3.5-turbo",
prompt="Hello, how are you?",
mock_response="I'm fine, thank you!",
)
assert response.choices[0].text == "I'm fine, thank you!"
## AEMBEDDING
response = await router.make_call(
original_function=router._aembedding,
model="gpt-3.5-turbo",
input="Hello, how are you?",
mock_response=[0.1, 0.2, 0.3],
)
assert response.data[0].embedding == [0.1, 0.2, 0.3]
## AIMAGE_GENERATION
response = await router.make_call(
original_function=router._aimage_generation,
model="dall-e-3",
prompt="A cute baby sea otter",
mock_response="https://example.com/image.png",
)
assert response.data[0].url == "https://example.com/image.png"