From 441adad3ae58acee088afa3dcc0a7dcedd7415d1 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 29 Oct 2024 21:07:17 +0530 Subject: [PATCH] (router_strategy/) ensure all async functions use async cache methods (#6489) * fix router strat * use async set / get cache in router_strategy * add coverage for router strategy * fix imports * fix batch_get_cache * use async methods for least busy * fix least busy use async methods * fix test_dual_cache_increment * test async_get_available_deployment when routing_strategy="least-busy" --- .circleci/config.yml | 1 + litellm/router.py | 11 ++ litellm/router_strategy/least_busy.py | 52 ++++++-- litellm/router_strategy/lowest_latency.py | 4 +- litellm/router_strategy/lowest_tpm_rpm.py | 12 +- .../test_router_strategy_async.py | 120 ++++++++++++++++++ tests/local_testing/test_dual_cache.py | 7 +- .../local_testing/test_least_busy_routing.py | 15 ++- 8 files changed, 202 insertions(+), 20 deletions(-) create mode 100644 tests/code_coverage_tests/test_router_strategy_async.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 8fcf51376..4734ee2a7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -424,6 +424,7 @@ jobs: - run: ruff check ./litellm - run: python ./tests/documentation_tests/test_general_setting_keys.py - run: python ./tests/code_coverage_tests/router_code_coverage.py + - run: python ./tests/code_coverage_tests/test_router_strategy_async.py - run: python ./tests/documentation_tests/test_env_keys.py - run: helm lint ./deploy/charts/litellm-helm diff --git a/litellm/router.py b/litellm/router.py index 5ccdbcf4a..e2c033c60 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5127,6 +5127,7 @@ class Router: and self.routing_strategy != "simple-shuffle" and self.routing_strategy != "cost-based-routing" and self.routing_strategy != "latency-based-routing" + and self.routing_strategy != "least-busy" ): # prevent regressions for other routing strategies, that don't have async get available deployments implemented. return self.get_available_deployment( model=model, @@ -5240,6 +5241,16 @@ class Router: healthy_deployments=healthy_deployments, model=model, ) + elif ( + self.routing_strategy == "least-busy" + and self.leastbusy_logger is not None + ): + deployment = ( + await self.leastbusy_logger.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, # type: ignore + ) + ) else: deployment = None if deployment is None: diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py index f1b35bb89..b1a85440f 100644 --- a/litellm/router_strategy/least_busy.py +++ b/litellm/router_strategy/least_busy.py @@ -145,13 +145,14 @@ class LeastBusyLoggingHandler(CustomLogger): request_count_api_key = f"{model_group}_request_count" # decrement count in cache request_count_dict = ( - self.router_cache.get_cache(key=request_count_api_key) or {} + await self.router_cache.async_get_cache(key=request_count_api_key) + or {} ) request_count_value: Optional[int] = request_count_dict.get(id, 0) if request_count_value is None: return request_count_dict[id] = request_count_value - 1 - self.router_cache.set_cache( + await self.router_cache.async_set_cache( key=request_count_api_key, value=request_count_dict ) @@ -178,13 +179,14 @@ class LeastBusyLoggingHandler(CustomLogger): request_count_api_key = f"{model_group}_request_count" # decrement count in cache request_count_dict = ( - self.router_cache.get_cache(key=request_count_api_key) or {} + await self.router_cache.async_get_cache(key=request_count_api_key) + or {} ) request_count_value: Optional[int] = request_count_dict.get(id, 0) if request_count_value is None: return request_count_dict[id] = request_count_value - 1 - self.router_cache.set_cache( + await self.router_cache.async_set_cache( key=request_count_api_key, value=request_count_dict ) @@ -194,10 +196,14 @@ class LeastBusyLoggingHandler(CustomLogger): except Exception: pass - def get_available_deployments(self, model_group: str, healthy_deployments: list): - request_count_api_key = f"{model_group}_request_count" - deployments = self.router_cache.get_cache(key=request_count_api_key) or {} - all_deployments = deployments + def _get_available_deployments( + self, + healthy_deployments: list, + all_deployments: dict, + ): + """ + Helper to get deployments using least busy strategy + """ for d in healthy_deployments: ## if healthy deployment not yet used if d["model_info"]["id"] not in all_deployments: @@ -219,3 +225,33 @@ class LeastBusyLoggingHandler(CustomLogger): else: min_deployment = random.choice(healthy_deployments) return min_deployment + + def get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + ): + """ + Sync helper to get deployments using least busy strategy + """ + request_count_api_key = f"{model_group}_request_count" + all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {} + return self._get_available_deployments( + healthy_deployments=healthy_deployments, + all_deployments=all_deployments, + ) + + async def async_get_available_deployments( + self, model_group: str, healthy_deployments: list + ): + """ + Async helper to get deployments using least busy strategy + """ + request_count_api_key = f"{model_group}_request_count" + all_deployments = ( + await self.router_cache.async_get_cache(key=request_count_api_key) or {} + ) + return self._get_available_deployments( + healthy_deployments=healthy_deployments, + all_deployments=all_deployments, + ) diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index 287e60146..a96a8fa94 100644 --- a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -243,7 +243,7 @@ class LowestLatencyLoggingHandler(CustomLogger): "latency" ][: self.routing_args.max_latency_list_size - 1] + [1000.0] - self.router_cache.set_cache( + await self.router_cache.async_set_cache( key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl, @@ -384,7 +384,7 @@ class LowestLatencyLoggingHandler(CustomLogger): request_count_dict[id][precise_minute].get("rpm", 0) + 1 ) - self.router_cache.set_cache( + await self.router_cache.async_set_cache( key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl ) # reset map within window diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py index 45f32fbf0..c79698ecf 100644 --- a/litellm/router_strategy/lowest_tpm_rpm.py +++ b/litellm/router_strategy/lowest_tpm_rpm.py @@ -139,18 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger): # update cache ## TPM - request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} + request_count_dict = ( + await self.router_cache.async_get_cache(key=tpm_key) or {} + ) request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens - self.router_cache.set_cache( + await self.router_cache.async_set_cache( key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl ) ## RPM - request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} + request_count_dict = ( + await self.router_cache.async_get_cache(key=rpm_key) or {} + ) request_count_dict[id] = request_count_dict.get(id, 0) + 1 - self.router_cache.set_cache( + await self.router_cache.async_set_cache( key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl ) diff --git a/tests/code_coverage_tests/test_router_strategy_async.py b/tests/code_coverage_tests/test_router_strategy_async.py new file mode 100644 index 000000000..05bdca10f --- /dev/null +++ b/tests/code_coverage_tests/test_router_strategy_async.py @@ -0,0 +1,120 @@ +""" +Test that all cache calls in async functions in router_strategy/ are async + +""" + +import os +import sys +from typing import Dict, List, Tuple +import ast + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import os + + +class AsyncCacheCallVisitor(ast.NodeVisitor): + def __init__(self): + self.async_functions: Dict[str, List[Tuple[str, int]]] = {} + self.current_function = None + + def visit_AsyncFunctionDef(self, node): + """Visit async function definitions and store their cache calls""" + self.current_function = node.name + self.async_functions[node.name] = [] + self.generic_visit(node) + self.current_function = None + + def visit_Call(self, node): + """Visit function calls and check for cache operations""" + if self.current_function is not None: + # Check if it's a cache-related call + if isinstance(node.func, ast.Attribute): + method_name = node.func.attr + if any(keyword in method_name.lower() for keyword in ["cache"]): + # Get the full method call path + if isinstance(node.func.value, ast.Name): + full_call = f"{node.func.value.id}.{method_name}" + elif isinstance(node.func.value, ast.Attribute): + # Handle nested attributes like self.router_cache.get + parts = [] + current = node.func.value + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + parts.reverse() + parts.append(method_name) + full_call = ".".join(parts) + else: + full_call = method_name + # Store both the call and its line number + self.async_functions[self.current_function].append( + (full_call, node.lineno) + ) + self.generic_visit(node) + + +def get_python_files(directory: str) -> List[str]: + """Get all Python files in the router_strategy directory""" + python_files = [] + for file in os.listdir(directory): + if file.endswith(".py") and not file.startswith("__"): + python_files.append(os.path.join(directory, file)) + return python_files + + +def analyze_file(file_path: str) -> Dict[str, List[Tuple[str, int]]]: + """Analyze a Python file for async functions and their cache calls""" + with open(file_path, "r") as file: + tree = ast.parse(file.read()) + + visitor = AsyncCacheCallVisitor() + visitor.visit(tree) + return visitor.async_functions + + +def test_router_strategy_async_cache_calls(): + """Test that all cache calls in async functions are properly async""" + strategy_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "litellm", + "router_strategy", + ) + + # Get all Python files in the router_strategy directory + python_files = get_python_files(strategy_dir) + + print("python files:", python_files) + + all_async_functions: Dict[str, Dict[str, List[Tuple[str, int]]]] = {} + + for file_path in python_files: + file_name = os.path.basename(file_path) + async_functions = analyze_file(file_path) + + if async_functions: + all_async_functions[file_name] = async_functions + print(f"\nAnalyzing {file_name}:") + + for func_name, cache_calls in async_functions.items(): + print(f"\nAsync function: {func_name}") + print(f"Cache calls found: {cache_calls}") + + # Assert that cache calls in async functions use async methods + for call, line_number in cache_calls: + if any(keyword in call.lower() for keyword in ["cache"]): + assert ( + "async" in call.lower() + ), f"VIOLATION: Cache call '{call}' in async function '{func_name}' should be async. file path: {file_path}, line number: {line_number}" + + # Assert we found async functions to analyze + assert ( + len(all_async_functions) > 0 + ), "No async functions found in router_strategy directory" + + +if __name__ == "__main__": + test_router_strategy_async_cache_calls() diff --git a/tests/local_testing/test_dual_cache.py b/tests/local_testing/test_dual_cache.py index d8c7cf358..c3f3216d5 100644 --- a/tests/local_testing/test_dual_cache.py +++ b/tests/local_testing/test_dual_cache.py @@ -158,7 +158,7 @@ async def test_dual_cache_batch_operations(is_async): if is_async: results = await dual_cache.async_batch_get_cache(test_keys) else: - results = dual_cache.batch_get_cache(test_keys) + results = dual_cache.batch_get_cache(test_keys, parent_otel_span=None) assert results == test_values mock_redis_get.assert_not_called() @@ -181,7 +181,10 @@ async def test_dual_cache_increment(is_async): ) as mock_redis_increment: if is_async: result = await dual_cache.async_increment_cache( - test_key, increment_value, local_only=True + test_key, + increment_value, + local_only=True, + parent_otel_span=None, ) else: result = dual_cache.increment_cache( diff --git a/tests/local_testing/test_least_busy_routing.py b/tests/local_testing/test_least_busy_routing.py index dc7db9560..c9c6eb609 100644 --- a/tests/local_testing/test_least_busy_routing.py +++ b/tests/local_testing/test_least_busy_routing.py @@ -65,7 +65,9 @@ def test_get_available_deployments(): # test_get_available_deployments() -def test_router_get_available_deployments(): +@pytest.mark.parametrize("async_test", [True, False]) +@pytest.mark.asyncio +async def test_router_get_available_deployments(async_test): """ Tests if 'get_available_deployments' returns the least busy deployment """ @@ -114,9 +116,14 @@ def test_router_get_available_deployments(): deployment = "azure/chatgpt-v-2" request_count_dict = {1: 10, 2: 54, 3: 100} cache_key = f"{model_group}_request_count" - router.cache.set_cache(key=cache_key, value=request_count_dict) - - deployment = router.get_available_deployment(model=model_group, messages=None) + if async_test is True: + await router.cache.async_set_cache(key=cache_key, value=request_count_dict) + deployment = await router.async_get_available_deployment( + model=model_group, messages=None + ) + else: + router.cache.set_cache(key=cache_key, value=request_count_dict) + deployment = router.get_available_deployment(model=model_group, messages=None) print(f"deployment: {deployment}") assert deployment["model_info"]["id"] == "1"