(router_strategy/) ensure all async functions use async cache methods (#6489)

* fix router strat

* use async set / get cache in router_strategy

* add coverage for router strategy

* fix imports

* fix batch_get_cache

* use async methods for least busy

* fix least busy use async methods

* fix test_dual_cache_increment

* test async_get_available_deployment when routing_strategy="least-busy"
This commit is contained in:
Ishaan Jaff 2024-10-29 21:07:17 +05:30 committed by GitHub
parent f9ba74ef87
commit 441adad3ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 202 additions and 20 deletions

View file

@ -424,6 +424,7 @@ jobs:
- run: ruff check ./litellm - run: ruff check ./litellm
- run: python ./tests/documentation_tests/test_general_setting_keys.py - run: python ./tests/documentation_tests/test_general_setting_keys.py
- run: python ./tests/code_coverage_tests/router_code_coverage.py - run: python ./tests/code_coverage_tests/router_code_coverage.py
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
- run: python ./tests/documentation_tests/test_env_keys.py - run: python ./tests/documentation_tests/test_env_keys.py
- run: helm lint ./deploy/charts/litellm-helm - run: helm lint ./deploy/charts/litellm-helm

View file

@ -5127,6 +5127,7 @@ class Router:
and self.routing_strategy != "simple-shuffle" and self.routing_strategy != "simple-shuffle"
and self.routing_strategy != "cost-based-routing" and self.routing_strategy != "cost-based-routing"
and self.routing_strategy != "latency-based-routing" and self.routing_strategy != "latency-based-routing"
and self.routing_strategy != "least-busy"
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented. ): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
return self.get_available_deployment( return self.get_available_deployment(
model=model, model=model,
@ -5240,6 +5241,16 @@ class Router:
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments,
model=model, model=model,
) )
elif (
self.routing_strategy == "least-busy"
and self.leastbusy_logger is not None
):
deployment = (
await self.leastbusy_logger.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
)
)
else: else:
deployment = None deployment = None
if deployment is None: if deployment is None:

View file

@ -145,13 +145,14 @@ class LeastBusyLoggingHandler(CustomLogger):
request_count_api_key = f"{model_group}_request_count" request_count_api_key = f"{model_group}_request_count"
# decrement count in cache # decrement count in cache
request_count_dict = ( request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {} await self.router_cache.async_get_cache(key=request_count_api_key)
or {}
) )
request_count_value: Optional[int] = request_count_dict.get(id, 0) request_count_value: Optional[int] = request_count_dict.get(id, 0)
if request_count_value is None: if request_count_value is None:
return return
request_count_dict[id] = request_count_value - 1 request_count_dict[id] = request_count_value - 1
self.router_cache.set_cache( await self.router_cache.async_set_cache(
key=request_count_api_key, value=request_count_dict key=request_count_api_key, value=request_count_dict
) )
@ -178,13 +179,14 @@ class LeastBusyLoggingHandler(CustomLogger):
request_count_api_key = f"{model_group}_request_count" request_count_api_key = f"{model_group}_request_count"
# decrement count in cache # decrement count in cache
request_count_dict = ( request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {} await self.router_cache.async_get_cache(key=request_count_api_key)
or {}
) )
request_count_value: Optional[int] = request_count_dict.get(id, 0) request_count_value: Optional[int] = request_count_dict.get(id, 0)
if request_count_value is None: if request_count_value is None:
return return
request_count_dict[id] = request_count_value - 1 request_count_dict[id] = request_count_value - 1
self.router_cache.set_cache( await self.router_cache.async_set_cache(
key=request_count_api_key, value=request_count_dict key=request_count_api_key, value=request_count_dict
) )
@ -194,10 +196,14 @@ class LeastBusyLoggingHandler(CustomLogger):
except Exception: except Exception:
pass pass
def get_available_deployments(self, model_group: str, healthy_deployments: list): def _get_available_deployments(
request_count_api_key = f"{model_group}_request_count" self,
deployments = self.router_cache.get_cache(key=request_count_api_key) or {} healthy_deployments: list,
all_deployments = deployments all_deployments: dict,
):
"""
Helper to get deployments using least busy strategy
"""
for d in healthy_deployments: for d in healthy_deployments:
## if healthy deployment not yet used ## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments: if d["model_info"]["id"] not in all_deployments:
@ -219,3 +225,33 @@ class LeastBusyLoggingHandler(CustomLogger):
else: else:
min_deployment = random.choice(healthy_deployments) min_deployment = random.choice(healthy_deployments)
return min_deployment return min_deployment
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
):
"""
Sync helper to get deployments using least busy strategy
"""
request_count_api_key = f"{model_group}_request_count"
all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
return self._get_available_deployments(
healthy_deployments=healthy_deployments,
all_deployments=all_deployments,
)
async def async_get_available_deployments(
self, model_group: str, healthy_deployments: list
):
"""
Async helper to get deployments using least busy strategy
"""
request_count_api_key = f"{model_group}_request_count"
all_deployments = (
await self.router_cache.async_get_cache(key=request_count_api_key) or {}
)
return self._get_available_deployments(
healthy_deployments=healthy_deployments,
all_deployments=all_deployments,
)

View file

@ -243,7 +243,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
"latency" "latency"
][: self.routing_args.max_latency_list_size - 1] + [1000.0] ][: self.routing_args.max_latency_list_size - 1] + [1000.0]
self.router_cache.set_cache( await self.router_cache.async_set_cache(
key=latency_key, key=latency_key,
value=request_count_dict, value=request_count_dict,
ttl=self.routing_args.ttl, ttl=self.routing_args.ttl,
@ -384,7 +384,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
request_count_dict[id][precise_minute].get("rpm", 0) + 1 request_count_dict[id][precise_minute].get("rpm", 0) + 1
) )
self.router_cache.set_cache( await self.router_cache.async_set_cache(
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
) # reset map within window ) # reset map within window

View file

@ -139,18 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger):
# update cache # update cache
## TPM ## TPM
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} request_count_dict = (
await self.router_cache.async_get_cache(key=tpm_key) or {}
)
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache( await self.router_cache.async_set_cache(
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
) )
## RPM ## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} request_count_dict = (
await self.router_cache.async_get_cache(key=rpm_key) or {}
)
request_count_dict[id] = request_count_dict.get(id, 0) + 1 request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache( await self.router_cache.async_set_cache(
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
) )

View file

@ -0,0 +1,120 @@
"""
Test that all cache calls in async functions in router_strategy/ are async
"""
import os
import sys
from typing import Dict, List, Tuple
import ast
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
class AsyncCacheCallVisitor(ast.NodeVisitor):
def __init__(self):
self.async_functions: Dict[str, List[Tuple[str, int]]] = {}
self.current_function = None
def visit_AsyncFunctionDef(self, node):
"""Visit async function definitions and store their cache calls"""
self.current_function = node.name
self.async_functions[node.name] = []
self.generic_visit(node)
self.current_function = None
def visit_Call(self, node):
"""Visit function calls and check for cache operations"""
if self.current_function is not None:
# Check if it's a cache-related call
if isinstance(node.func, ast.Attribute):
method_name = node.func.attr
if any(keyword in method_name.lower() for keyword in ["cache"]):
# Get the full method call path
if isinstance(node.func.value, ast.Name):
full_call = f"{node.func.value.id}.{method_name}"
elif isinstance(node.func.value, ast.Attribute):
# Handle nested attributes like self.router_cache.get
parts = []
current = node.func.value
while isinstance(current, ast.Attribute):
parts.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
parts.append(current.id)
parts.reverse()
parts.append(method_name)
full_call = ".".join(parts)
else:
full_call = method_name
# Store both the call and its line number
self.async_functions[self.current_function].append(
(full_call, node.lineno)
)
self.generic_visit(node)
def get_python_files(directory: str) -> List[str]:
"""Get all Python files in the router_strategy directory"""
python_files = []
for file in os.listdir(directory):
if file.endswith(".py") and not file.startswith("__"):
python_files.append(os.path.join(directory, file))
return python_files
def analyze_file(file_path: str) -> Dict[str, List[Tuple[str, int]]]:
"""Analyze a Python file for async functions and their cache calls"""
with open(file_path, "r") as file:
tree = ast.parse(file.read())
visitor = AsyncCacheCallVisitor()
visitor.visit(tree)
return visitor.async_functions
def test_router_strategy_async_cache_calls():
"""Test that all cache calls in async functions are properly async"""
strategy_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"litellm",
"router_strategy",
)
# Get all Python files in the router_strategy directory
python_files = get_python_files(strategy_dir)
print("python files:", python_files)
all_async_functions: Dict[str, Dict[str, List[Tuple[str, int]]]] = {}
for file_path in python_files:
file_name = os.path.basename(file_path)
async_functions = analyze_file(file_path)
if async_functions:
all_async_functions[file_name] = async_functions
print(f"\nAnalyzing {file_name}:")
for func_name, cache_calls in async_functions.items():
print(f"\nAsync function: {func_name}")
print(f"Cache calls found: {cache_calls}")
# Assert that cache calls in async functions use async methods
for call, line_number in cache_calls:
if any(keyword in call.lower() for keyword in ["cache"]):
assert (
"async" in call.lower()
), f"VIOLATION: Cache call '{call}' in async function '{func_name}' should be async. file path: {file_path}, line number: {line_number}"
# Assert we found async functions to analyze
assert (
len(all_async_functions) > 0
), "No async functions found in router_strategy directory"
if __name__ == "__main__":
test_router_strategy_async_cache_calls()

View file

@ -158,7 +158,7 @@ async def test_dual_cache_batch_operations(is_async):
if is_async: if is_async:
results = await dual_cache.async_batch_get_cache(test_keys) results = await dual_cache.async_batch_get_cache(test_keys)
else: else:
results = dual_cache.batch_get_cache(test_keys) results = dual_cache.batch_get_cache(test_keys, parent_otel_span=None)
assert results == test_values assert results == test_values
mock_redis_get.assert_not_called() mock_redis_get.assert_not_called()
@ -181,7 +181,10 @@ async def test_dual_cache_increment(is_async):
) as mock_redis_increment: ) as mock_redis_increment:
if is_async: if is_async:
result = await dual_cache.async_increment_cache( result = await dual_cache.async_increment_cache(
test_key, increment_value, local_only=True test_key,
increment_value,
local_only=True,
parent_otel_span=None,
) )
else: else:
result = dual_cache.increment_cache( result = dual_cache.increment_cache(

View file

@ -65,7 +65,9 @@ def test_get_available_deployments():
# test_get_available_deployments() # test_get_available_deployments()
def test_router_get_available_deployments(): @pytest.mark.parametrize("async_test", [True, False])
@pytest.mark.asyncio
async def test_router_get_available_deployments(async_test):
""" """
Tests if 'get_available_deployments' returns the least busy deployment Tests if 'get_available_deployments' returns the least busy deployment
""" """
@ -114,8 +116,13 @@ def test_router_get_available_deployments():
deployment = "azure/chatgpt-v-2" deployment = "azure/chatgpt-v-2"
request_count_dict = {1: 10, 2: 54, 3: 100} request_count_dict = {1: 10, 2: 54, 3: 100}
cache_key = f"{model_group}_request_count" cache_key = f"{model_group}_request_count"
if async_test is True:
await router.cache.async_set_cache(key=cache_key, value=request_count_dict)
deployment = await router.async_get_available_deployment(
model=model_group, messages=None
)
else:
router.cache.set_cache(key=cache_key, value=request_count_dict) router.cache.set_cache(key=cache_key, value=request_count_dict)
deployment = router.get_available_deployment(model=model_group, messages=None) deployment = router.get_available_deployment(model=model_group, messages=None)
print(f"deployment: {deployment}") print(f"deployment: {deployment}")
assert deployment["model_info"]["id"] == "1" assert deployment["model_info"]["id"] == "1"