mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

(router_strategy/) ensure all async functions use async cache methods (#6489)

* fix router strat
* use async set / get cache in router_strategy
* add coverage for router strategy
* fix imports
* fix batch_get_cache
* use async methods for least busy
* fix least busy use async methods
* fix test_dual_cache_increment
* test async_get_available_deployment when routing_strategy="least-busy"

parent f9ba74ef87
commit 441adad3ae

8 changed files with 202 additions and 20 deletions
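The heart of the change, repeated across the hunks below: any cache read or write inside an `async def` now awaits the async variant of the cache method instead of calling the blocking one. A minimal before/after sketch, assuming a DualCache-style object exposing both sync and async methods (the wrapper function here is hypothetical, not from the commit):

async def update_counter(router_cache, key: str) -> None:
    # before: blocking call inside an async function
    # counts = router_cache.get_cache(key=key) or {}

    # after: awaitable variants, as used throughout this commit
    counts = await router_cache.async_get_cache(key=key) or {}
    counts["deployment-1"] = counts.get("deployment-1", 0) + 1
    await router_cache.async_set_cache(key=key, value=counts)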
@@ -424,6 +424,7 @@ jobs:
       - run: ruff check ./litellm
       - run: python ./tests/documentation_tests/test_general_setting_keys.py
       - run: python ./tests/code_coverage_tests/router_code_coverage.py
+      - run: python ./tests/code_coverage_tests/test_router_strategy_async.py
       - run: python ./tests/documentation_tests/test_env_keys.py
       - run: helm lint ./deploy/charts/litellm-helm
@@ -5127,6 +5127,7 @@ class Router:
             and self.routing_strategy != "simple-shuffle"
             and self.routing_strategy != "cost-based-routing"
             and self.routing_strategy != "latency-based-routing"
+            and self.routing_strategy != "least-busy"
         ):  # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
             return self.get_available_deployment(
                 model=model,
@@ -5240,6 +5241,16 @@ class Router:
                 healthy_deployments=healthy_deployments,
                 model=model,
             )
+        elif (
+            self.routing_strategy == "least-busy"
+            and self.leastbusy_logger is not None
+        ):
+            deployment = (
+                await self.leastbusy_logger.async_get_available_deployments(
+                    model_group=model,
+                    healthy_deployments=healthy_deployments,  # type: ignore
+                )
+            )
         else:
             deployment = None
         if deployment is None:
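With "least-busy" excluded from the sync fallback and handled by the new elif branch, async routing no longer drops to the blocking path for this strategy. An illustrative call, assuming a Router configured with routing_strategy="least-busy" (the model name is illustrative; the signature mirrors the test near the end of this diff):

deployment = await router.async_get_available_deployment(
    model="gpt-3.5-turbo", messages=None
)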
@@ -145,13 +145,14 @@ class LeastBusyLoggingHandler(CustomLogger):
                 request_count_api_key = f"{model_group}_request_count"
                 # decrement count in cache
                 request_count_dict = (
-                    self.router_cache.get_cache(key=request_count_api_key) or {}
+                    await self.router_cache.async_get_cache(key=request_count_api_key)
+                    or {}
                 )
                 request_count_value: Optional[int] = request_count_dict.get(id, 0)
                 if request_count_value is None:
                     return
                 request_count_dict[id] = request_count_value - 1
-                self.router_cache.set_cache(
+                await self.router_cache.async_set_cache(
                     key=request_count_api_key, value=request_count_dict
                 )
@@ -178,13 +179,14 @@ class LeastBusyLoggingHandler(CustomLogger):
                 request_count_api_key = f"{model_group}_request_count"
                 # decrement count in cache
                 request_count_dict = (
-                    self.router_cache.get_cache(key=request_count_api_key) or {}
+                    await self.router_cache.async_get_cache(key=request_count_api_key)
+                    or {}
                 )
                 request_count_value: Optional[int] = request_count_dict.get(id, 0)
                 if request_count_value is None:
                     return
                 request_count_dict[id] = request_count_value - 1
-                self.router_cache.set_cache(
+                await self.router_cache.async_set_cache(
                     key=request_count_api_key, value=request_count_dict
                 )
@@ -194,10 +196,14 @@ class LeastBusyLoggingHandler(CustomLogger):
         except Exception:
             pass

-    def get_available_deployments(self, model_group: str, healthy_deployments: list):
-        request_count_api_key = f"{model_group}_request_count"
-        deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
-        all_deployments = deployments
+    def _get_available_deployments(
+        self,
+        healthy_deployments: list,
+        all_deployments: dict,
+    ):
+        """
+        Helper to get deployments using least busy strategy
+        """
         for d in healthy_deployments:
             ## if healthy deployment not yet used
             if d["model_info"]["id"] not in all_deployments:
@@ -219,3 +225,33 @@ class LeastBusyLoggingHandler(CustomLogger):
         else:
             min_deployment = random.choice(healthy_deployments)
         return min_deployment
+
+    def get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+    ):
+        """
+        Sync helper to get deployments using least busy strategy
+        """
+        request_count_api_key = f"{model_group}_request_count"
+        all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
+        return self._get_available_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+        )
+
+    async def async_get_available_deployments(
+        self, model_group: str, healthy_deployments: list
+    ):
+        """
+        Async helper to get deployments using least busy strategy
+        """
+        request_count_api_key = f"{model_group}_request_count"
+        all_deployments = (
+            await self.router_cache.async_get_cache(key=request_count_api_key) or {}
+        )
+        return self._get_available_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+        )
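The selection logic inside _get_available_deployments is mostly elided between these hunks. As a rough reconstruction of the least-busy choice it makes, assuming all_deployments maps deployment IDs to in-flight request counts (a sketch, not the verbatim helper):

import random

def pick_least_busy(healthy_deployments: list, all_deployments: dict) -> dict:
    # treat unseen healthy deployments as having zero in-flight requests
    for d in healthy_deployments:
        all_deployments.setdefault(d["model_info"]["id"], 0)
    by_id = {d["model_info"]["id"]: d for d in healthy_deployments}
    usable = [(count, _id) for _id, count in all_deployments.items() if _id in by_id]
    if not usable:
        # the diff's visible fallback: pick at random
        return random.choice(healthy_deployments)
    _, least_busy_id = min(usable)
    return by_id[least_busy_id]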
@@ -243,7 +243,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
                             "latency"
                         ][: self.routing_args.max_latency_list_size - 1] + [1000.0]

-            self.router_cache.set_cache(
+            await self.router_cache.async_set_cache(
                 key=latency_key,
                 value=request_count_dict,
                 ttl=self.routing_args.ttl,
@@ -384,7 +384,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
                     request_count_dict[id][precise_minute].get("rpm", 0) + 1
                 )

-            self.router_cache.set_cache(
+            await self.router_cache.async_set_cache(
                 key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
             )  # reset map within window
@@ -139,18 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger):
             # update cache

             ## TPM
-            request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
+            request_count_dict = (
+                await self.router_cache.async_get_cache(key=tpm_key) or {}
+            )
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

-            self.router_cache.set_cache(
+            await self.router_cache.async_set_cache(
                 key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
             )

             ## RPM
-            request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
+            request_count_dict = (
+                await self.router_cache.async_get_cache(key=rpm_key) or {}
+            )
             request_count_dict[id] = request_count_dict.get(id, 0) + 1

-            self.router_cache.set_cache(
+            await self.router_cache.async_set_cache(
                 key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
             )
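The TPM and RPM updates above share one read-modify-write shape: fetch the counter dict, bump this deployment's entry, write it back with a TTL so the window expires on its own. A self-contained sketch of that shape with a toy in-memory cache (the real code uses litellm's DualCache; TTL expiry is stubbed out here and the key format is illustrative):

import asyncio

class ToyAsyncCache:
    def __init__(self):
        self._store: dict = {}

    async def async_get_cache(self, key):
        return self._store.get(key)

    async def async_set_cache(self, key, value, ttl=None):
        # a real cache would drop `key` after `ttl` seconds
        self._store[key] = value

async def record_tokens(cache, tpm_key: str, deployment_id: str, total_tokens: int):
    counts = await cache.async_get_cache(key=tpm_key) or {}
    counts[deployment_id] = counts.get(deployment_id, 0) + total_tokens
    await cache.async_set_cache(key=tpm_key, value=counts, ttl=60)

asyncio.run(record_tokens(ToyAsyncCache(), "gpt-4:17-30", "deployment-1", 128))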
tests/code_coverage_tests/test_router_strategy_async.py — 120 lines (new file)

@@ -0,0 +1,120 @@
+"""
+Test that all cache calls in async functions in router_strategy/ are async
+
+"""
+
+import os
+import sys
+from typing import Dict, List, Tuple
+import ast
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import os
+
+
+class AsyncCacheCallVisitor(ast.NodeVisitor):
+    def __init__(self):
+        self.async_functions: Dict[str, List[Tuple[str, int]]] = {}
+        self.current_function = None
+
+    def visit_AsyncFunctionDef(self, node):
+        """Visit async function definitions and store their cache calls"""
+        self.current_function = node.name
+        self.async_functions[node.name] = []
+        self.generic_visit(node)
+        self.current_function = None
+
+    def visit_Call(self, node):
+        """Visit function calls and check for cache operations"""
+        if self.current_function is not None:
+            # Check if it's a cache-related call
+            if isinstance(node.func, ast.Attribute):
+                method_name = node.func.attr
+                if any(keyword in method_name.lower() for keyword in ["cache"]):
+                    # Get the full method call path
+                    if isinstance(node.func.value, ast.Name):
+                        full_call = f"{node.func.value.id}.{method_name}"
+                    elif isinstance(node.func.value, ast.Attribute):
+                        # Handle nested attributes like self.router_cache.get
+                        parts = []
+                        current = node.func.value
+                        while isinstance(current, ast.Attribute):
+                            parts.append(current.attr)
+                            current = current.value
+                        if isinstance(current, ast.Name):
+                            parts.append(current.id)
+                        parts.reverse()
+                        parts.append(method_name)
+                        full_call = ".".join(parts)
+                    else:
+                        full_call = method_name
+                    # Store both the call and its line number
+                    self.async_functions[self.current_function].append(
+                        (full_call, node.lineno)
+                    )
+        self.generic_visit(node)
+
+
+def get_python_files(directory: str) -> List[str]:
+    """Get all Python files in the router_strategy directory"""
+    python_files = []
+    for file in os.listdir(directory):
+        if file.endswith(".py") and not file.startswith("__"):
+            python_files.append(os.path.join(directory, file))
+    return python_files
+
+
+def analyze_file(file_path: str) -> Dict[str, List[Tuple[str, int]]]:
+    """Analyze a Python file for async functions and their cache calls"""
+    with open(file_path, "r") as file:
+        tree = ast.parse(file.read())
+
+    visitor = AsyncCacheCallVisitor()
+    visitor.visit(tree)
+    return visitor.async_functions
+
+
+def test_router_strategy_async_cache_calls():
+    """Test that all cache calls in async functions are properly async"""
+    strategy_dir = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+        "litellm",
+        "router_strategy",
+    )
+
+    # Get all Python files in the router_strategy directory
+    python_files = get_python_files(strategy_dir)
+
+    print("python files:", python_files)
+
+    all_async_functions: Dict[str, Dict[str, List[Tuple[str, int]]]] = {}
+
+    for file_path in python_files:
+        file_name = os.path.basename(file_path)
+        async_functions = analyze_file(file_path)
+
+        if async_functions:
+            all_async_functions[file_name] = async_functions
+            print(f"\nAnalyzing {file_name}:")
+
+            for func_name, cache_calls in async_functions.items():
+                print(f"\nAsync function: {func_name}")
+                print(f"Cache calls found: {cache_calls}")
+
+                # Assert that cache calls in async functions use async methods
+                for call, line_number in cache_calls:
+                    if any(keyword in call.lower() for keyword in ["cache"]):
+                        assert (
+                            "async" in call.lower()
+                        ), f"VIOLATION: Cache call '{call}' in async function '{func_name}' should be async. file path: {file_path}, line number: {line_number}"
+
+    # Assert we found async functions to analyze
+    assert (
+        len(all_async_functions) > 0
+    ), "No async functions found in router_strategy directory"
+
+
+if __name__ == "__main__":
+    test_router_strategy_async_cache_calls()
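The check runs standalone (python tests/code_coverage_tests/test_router_strategy_async.py) as well as in CI via the config change at the top of this diff. To illustrate what the visitor flags, a hypothetical file in router_strategy/ containing:

async def bad_handler(self):
    self.router_cache.get_cache(key="k")  # flagged: name contains "cache" but not "async"
    await self.router_cache.async_get_cache(key="k")  # passes

would fail the assertion on the first call, with the file path and line number in the error message.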
@@ -158,7 +158,7 @@ async def test_dual_cache_batch_operations(is_async):
         if is_async:
             results = await dual_cache.async_batch_get_cache(test_keys)
         else:
-            results = dual_cache.batch_get_cache(test_keys)
+            results = dual_cache.batch_get_cache(test_keys, parent_otel_span=None)

         assert results == test_values
         mock_redis_get.assert_not_called()
@@ -181,7 +181,10 @@ async def test_dual_cache_increment(is_async):
     ) as mock_redis_increment:
         if is_async:
             result = await dual_cache.async_increment_cache(
-                test_key, increment_value, local_only=True
+                test_key,
+                increment_value,
+                local_only=True,
+                parent_otel_span=None,
             )
         else:
             result = dual_cache.increment_cache(
@@ -65,7 +65,9 @@ def test_get_available_deployments():
 # test_get_available_deployments()


-def test_router_get_available_deployments():
+@pytest.mark.parametrize("async_test", [True, False])
+@pytest.mark.asyncio
+async def test_router_get_available_deployments(async_test):
     """
     Tests if 'get_available_deployments' returns the least busy deployment
     """

@@ -114,8 +116,13 @@ def test_router_get_available_deployments():
     deployment = "azure/chatgpt-v-2"
     request_count_dict = {1: 10, 2: 54, 3: 100}
     cache_key = f"{model_group}_request_count"
-    router.cache.set_cache(key=cache_key, value=request_count_dict)
-
-    deployment = router.get_available_deployment(model=model_group, messages=None)
+    if async_test is True:
+        await router.cache.async_set_cache(key=cache_key, value=request_count_dict)
+        deployment = await router.async_get_available_deployment(
+            model=model_group, messages=None
+        )
+    else:
+        router.cache.set_cache(key=cache_key, value=request_count_dict)
+        deployment = router.get_available_deployment(model=model_group, messages=None)
     print(f"deployment: {deployment}")
     assert deployment["model_info"]["id"] == "1"