mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
(router_strategy/) ensure all async functions use async cache methods (#6489)
* fix router strat * use async set / get cache in router_strategy * add coverage for router strategy * fix imports * fix batch_get_cache * use async methods for least busy * fix least busy use async methods * fix test_dual_cache_increment * test async_get_available_deployment when routing_strategy="least-busy"
This commit is contained in:
parent
f9ba74ef87
commit
441adad3ae
8 changed files with 202 additions and 20 deletions
|
@ -424,6 +424,7 @@ jobs:
|
|||
- run: ruff check ./litellm
|
||||
- run: python ./tests/documentation_tests/test_general_setting_keys.py
|
||||
- run: python ./tests/code_coverage_tests/router_code_coverage.py
|
||||
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
|
||||
- run: python ./tests/documentation_tests/test_env_keys.py
|
||||
- run: helm lint ./deploy/charts/litellm-helm
|
||||
|
||||
|
|
|
@ -5127,6 +5127,7 @@ class Router:
|
|||
and self.routing_strategy != "simple-shuffle"
|
||||
and self.routing_strategy != "cost-based-routing"
|
||||
and self.routing_strategy != "latency-based-routing"
|
||||
and self.routing_strategy != "least-busy"
|
||||
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
|
||||
return self.get_available_deployment(
|
||||
model=model,
|
||||
|
@ -5240,6 +5241,16 @@ class Router:
|
|||
healthy_deployments=healthy_deployments,
|
||||
model=model,
|
||||
)
|
||||
elif (
|
||||
self.routing_strategy == "least-busy"
|
||||
and self.leastbusy_logger is not None
|
||||
):
|
||||
deployment = (
|
||||
await self.leastbusy_logger.async_get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
deployment = None
|
||||
if deployment is None:
|
||||
|
|
|
@ -145,13 +145,14 @@ class LeastBusyLoggingHandler(CustomLogger):
|
|||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = (
|
||||
self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
await self.router_cache.async_get_cache(key=request_count_api_key)
|
||||
or {}
|
||||
)
|
||||
request_count_value: Optional[int] = request_count_dict.get(id, 0)
|
||||
if request_count_value is None:
|
||||
return
|
||||
request_count_dict[id] = request_count_value - 1
|
||||
self.router_cache.set_cache(
|
||||
await self.router_cache.async_set_cache(
|
||||
key=request_count_api_key, value=request_count_dict
|
||||
)
|
||||
|
||||
|
@ -178,13 +179,14 @@ class LeastBusyLoggingHandler(CustomLogger):
|
|||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = (
|
||||
self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
await self.router_cache.async_get_cache(key=request_count_api_key)
|
||||
or {}
|
||||
)
|
||||
request_count_value: Optional[int] = request_count_dict.get(id, 0)
|
||||
if request_count_value is None:
|
||||
return
|
||||
request_count_dict[id] = request_count_value - 1
|
||||
self.router_cache.set_cache(
|
||||
await self.router_cache.async_set_cache(
|
||||
key=request_count_api_key, value=request_count_dict
|
||||
)
|
||||
|
||||
|
@ -194,10 +196,14 @@ class LeastBusyLoggingHandler(CustomLogger):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
def get_available_deployments(self, model_group: str, healthy_deployments: list):
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
all_deployments = deployments
|
||||
def _get_available_deployments(
|
||||
self,
|
||||
healthy_deployments: list,
|
||||
all_deployments: dict,
|
||||
):
|
||||
"""
|
||||
Helper to get deployments using least busy strategy
|
||||
"""
|
||||
for d in healthy_deployments:
|
||||
## if healthy deployment not yet used
|
||||
if d["model_info"]["id"] not in all_deployments:
|
||||
|
@ -219,3 +225,33 @@ class LeastBusyLoggingHandler(CustomLogger):
|
|||
else:
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
return min_deployment
|
||||
|
||||
def get_available_deployments(
|
||||
self,
|
||||
model_group: str,
|
||||
healthy_deployments: list,
|
||||
):
|
||||
"""
|
||||
Sync helper to get deployments using least busy strategy
|
||||
"""
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
return self._get_available_deployments(
|
||||
healthy_deployments=healthy_deployments,
|
||||
all_deployments=all_deployments,
|
||||
)
|
||||
|
||||
async def async_get_available_deployments(
|
||||
self, model_group: str, healthy_deployments: list
|
||||
):
|
||||
"""
|
||||
Async helper to get deployments using least busy strategy
|
||||
"""
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
all_deployments = (
|
||||
await self.router_cache.async_get_cache(key=request_count_api_key) or {}
|
||||
)
|
||||
return self._get_available_deployments(
|
||||
healthy_deployments=healthy_deployments,
|
||||
all_deployments=all_deployments,
|
||||
)
|
||||
|
|
|
@ -243,7 +243,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
|||
"latency"
|
||||
][: self.routing_args.max_latency_list_size - 1] + [1000.0]
|
||||
|
||||
self.router_cache.set_cache(
|
||||
await self.router_cache.async_set_cache(
|
||||
key=latency_key,
|
||||
value=request_count_dict,
|
||||
ttl=self.routing_args.ttl,
|
||||
|
@ -384,7 +384,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
|||
request_count_dict[id][precise_minute].get("rpm", 0) + 1
|
||||
)
|
||||
|
||||
self.router_cache.set_cache(
|
||||
await self.router_cache.async_set_cache(
|
||||
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
) # reset map within window
|
||||
|
||||
|
|
|
@ -139,18 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger):
|
|||
# update cache
|
||||
|
||||
## TPM
|
||||
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=tpm_key) or {}
|
||||
)
|
||||
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
|
||||
|
||||
self.router_cache.set_cache(
|
||||
await self.router_cache.async_set_cache(
|
||||
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
)
|
||||
|
||||
## RPM
|
||||
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=rpm_key) or {}
|
||||
)
|
||||
request_count_dict[id] = request_count_dict.get(id, 0) + 1
|
||||
|
||||
self.router_cache.set_cache(
|
||||
await self.router_cache.async_set_cache(
|
||||
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
)
|
||||
|
||||
|
|
120
tests/code_coverage_tests/test_router_strategy_async.py
Normal file
120
tests/code_coverage_tests/test_router_strategy_async.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
Test that all cache calls in async functions in router_strategy/ are async
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List, Tuple
|
||||
import ast
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import os
|
||||
|
||||
|
||||
class AsyncCacheCallVisitor(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.async_functions: Dict[str, List[Tuple[str, int]]] = {}
|
||||
self.current_function = None
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
"""Visit async function definitions and store their cache calls"""
|
||||
self.current_function = node.name
|
||||
self.async_functions[node.name] = []
|
||||
self.generic_visit(node)
|
||||
self.current_function = None
|
||||
|
||||
def visit_Call(self, node):
|
||||
"""Visit function calls and check for cache operations"""
|
||||
if self.current_function is not None:
|
||||
# Check if it's a cache-related call
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
method_name = node.func.attr
|
||||
if any(keyword in method_name.lower() for keyword in ["cache"]):
|
||||
# Get the full method call path
|
||||
if isinstance(node.func.value, ast.Name):
|
||||
full_call = f"{node.func.value.id}.{method_name}"
|
||||
elif isinstance(node.func.value, ast.Attribute):
|
||||
# Handle nested attributes like self.router_cache.get
|
||||
parts = []
|
||||
current = node.func.value
|
||||
while isinstance(current, ast.Attribute):
|
||||
parts.append(current.attr)
|
||||
current = current.value
|
||||
if isinstance(current, ast.Name):
|
||||
parts.append(current.id)
|
||||
parts.reverse()
|
||||
parts.append(method_name)
|
||||
full_call = ".".join(parts)
|
||||
else:
|
||||
full_call = method_name
|
||||
# Store both the call and its line number
|
||||
self.async_functions[self.current_function].append(
|
||||
(full_call, node.lineno)
|
||||
)
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def get_python_files(directory: str) -> List[str]:
|
||||
"""Get all Python files in the router_strategy directory"""
|
||||
python_files = []
|
||||
for file in os.listdir(directory):
|
||||
if file.endswith(".py") and not file.startswith("__"):
|
||||
python_files.append(os.path.join(directory, file))
|
||||
return python_files
|
||||
|
||||
|
||||
def analyze_file(file_path: str) -> Dict[str, List[Tuple[str, int]]]:
|
||||
"""Analyze a Python file for async functions and their cache calls"""
|
||||
with open(file_path, "r") as file:
|
||||
tree = ast.parse(file.read())
|
||||
|
||||
visitor = AsyncCacheCallVisitor()
|
||||
visitor.visit(tree)
|
||||
return visitor.async_functions
|
||||
|
||||
|
||||
def test_router_strategy_async_cache_calls():
|
||||
"""Test that all cache calls in async functions are properly async"""
|
||||
strategy_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"litellm",
|
||||
"router_strategy",
|
||||
)
|
||||
|
||||
# Get all Python files in the router_strategy directory
|
||||
python_files = get_python_files(strategy_dir)
|
||||
|
||||
print("python files:", python_files)
|
||||
|
||||
all_async_functions: Dict[str, Dict[str, List[Tuple[str, int]]]] = {}
|
||||
|
||||
for file_path in python_files:
|
||||
file_name = os.path.basename(file_path)
|
||||
async_functions = analyze_file(file_path)
|
||||
|
||||
if async_functions:
|
||||
all_async_functions[file_name] = async_functions
|
||||
print(f"\nAnalyzing {file_name}:")
|
||||
|
||||
for func_name, cache_calls in async_functions.items():
|
||||
print(f"\nAsync function: {func_name}")
|
||||
print(f"Cache calls found: {cache_calls}")
|
||||
|
||||
# Assert that cache calls in async functions use async methods
|
||||
for call, line_number in cache_calls:
|
||||
if any(keyword in call.lower() for keyword in ["cache"]):
|
||||
assert (
|
||||
"async" in call.lower()
|
||||
), f"VIOLATION: Cache call '{call}' in async function '{func_name}' should be async. file path: {file_path}, line number: {line_number}"
|
||||
|
||||
# Assert we found async functions to analyze
|
||||
assert (
|
||||
len(all_async_functions) > 0
|
||||
), "No async functions found in router_strategy directory"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_router_strategy_async_cache_calls()
|
|
@ -158,7 +158,7 @@ async def test_dual_cache_batch_operations(is_async):
|
|||
if is_async:
|
||||
results = await dual_cache.async_batch_get_cache(test_keys)
|
||||
else:
|
||||
results = dual_cache.batch_get_cache(test_keys)
|
||||
results = dual_cache.batch_get_cache(test_keys, parent_otel_span=None)
|
||||
|
||||
assert results == test_values
|
||||
mock_redis_get.assert_not_called()
|
||||
|
@ -181,7 +181,10 @@ async def test_dual_cache_increment(is_async):
|
|||
) as mock_redis_increment:
|
||||
if is_async:
|
||||
result = await dual_cache.async_increment_cache(
|
||||
test_key, increment_value, local_only=True
|
||||
test_key,
|
||||
increment_value,
|
||||
local_only=True,
|
||||
parent_otel_span=None,
|
||||
)
|
||||
else:
|
||||
result = dual_cache.increment_cache(
|
||||
|
|
|
@ -65,7 +65,9 @@ def test_get_available_deployments():
|
|||
# test_get_available_deployments()
|
||||
|
||||
|
||||
def test_router_get_available_deployments():
|
||||
@pytest.mark.parametrize("async_test", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_get_available_deployments(async_test):
|
||||
"""
|
||||
Tests if 'get_available_deployments' returns the least busy deployment
|
||||
"""
|
||||
|
@ -114,8 +116,13 @@ def test_router_get_available_deployments():
|
|||
deployment = "azure/chatgpt-v-2"
|
||||
request_count_dict = {1: 10, 2: 54, 3: 100}
|
||||
cache_key = f"{model_group}_request_count"
|
||||
if async_test is True:
|
||||
await router.cache.async_set_cache(key=cache_key, value=request_count_dict)
|
||||
deployment = await router.async_get_available_deployment(
|
||||
model=model_group, messages=None
|
||||
)
|
||||
else:
|
||||
router.cache.set_cache(key=cache_key, value=request_count_dict)
|
||||
|
||||
deployment = router.get_available_deployment(model=model_group, messages=None)
|
||||
print(f"deployment: {deployment}")
|
||||
assert deployment["model_info"]["id"] == "1"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue