From e22e8d24ef5be30f3dd83df87b200ccabd581d6b Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Wed, 16 Oct 2024 21:30:25 -0700 Subject: [PATCH] Litellm router code coverage 3 (#6274) * refactor(router.py): move assistants api endpoints to using 1 pass-through factory function Reduces code, increases testing coverage * refactor(router.py): reduce _common_check_available_deployment function size make code more maintainable - reduce possible errors * test(router_code_coverage.py): include batch_utils + pattern matching in enforced 100% code coverage Improves reliability * fix(router.py): fix model id match model dump --- litellm/proxy/_new_secret_config.yaml | 6 + litellm/router.py | 319 +++++------------- .../router_utils/pattern_match_deployments.py | 51 +++ .../router_code_coverage.py | 27 +- .../router_enforce_line_length.py | 66 ++++ tests/local_testing/test_router.py | 18 + .../test_router_batch_utils.py | 84 +++++ .../test_router_helper_utils.py | 80 +++++ 8 files changed, 407 insertions(+), 244 deletions(-) create mode 100644 tests/code_coverage_tests/router_enforce_line_length.py create mode 100644 tests/router_unit_tests/test_router_batch_utils.py diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 883a02b6e..5f847c04c 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -25,3 +25,9 @@ model_list: # guard_name: "gibberish_guard" # mode: "post_call" # api_base: os.environ/GUARDRAILS_AI_API_BASE + +assistant_settings: + custom_llm_provider: azure + litellm_params: + api_key: os.environ/AZURE_API_KEY + api_base: os.environ/AZURE_API_BASE \ No newline at end of file diff --git a/litellm/router.py b/litellm/router.py index cec462846..aa02cfe63 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -24,7 +24,18 @@ import traceback import uuid from collections import defaultdict from datetime import datetime -from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Tuple, + TypedDict, + Union, +) import httpx import openai @@ -520,6 +531,19 @@ class Router: if self.alerting_config is not None: self._initialize_alerting() + self.initialize_assistants_endpoint() + + def initialize_assistants_endpoint(self): + ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ## + self.acreate_assistants = self.factory_function(litellm.acreate_assistants) + self.adelete_assistant = self.factory_function(litellm.adelete_assistant) + self.aget_assistants = self.factory_function(litellm.aget_assistants) + self.acreate_thread = self.factory_function(litellm.acreate_thread) + self.aget_thread = self.factory_function(litellm.aget_thread) + self.a_add_message = self.factory_function(litellm.a_add_message) + self.aget_messages = self.factory_function(litellm.aget_messages) + self.arun_thread = self.factory_function(litellm.arun_thread) + def validate_fallbacks(self, fallback_param: Optional[List]): """ Validate the fallbacks parameter. @@ -2167,7 +2191,6 @@ class Router: raise e #### FILES API #### - async def acreate_file( self, model: str, @@ -2504,114 +2527,29 @@ class Router: #### ASSISTANTS API #### - async def acreate_assistants( - self, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, - client: Optional[AsyncOpenAI] = None, - **kwargs, - ) -> Assistant: - if custom_llm_provider is None: - if self.assistants_config is not None: - custom_llm_provider = self.assistants_config["custom_llm_provider"] - kwargs.update(self.assistants_config["litellm_params"]) - else: - raise Exception( - "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" - ) - - return await litellm.acreate_assistants( - custom_llm_provider=custom_llm_provider, client=client, **kwargs - ) - - async def adelete_assistant( - self, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, - client: Optional[AsyncOpenAI] = None, - **kwargs, - ) -> AssistantDeleted: - if custom_llm_provider is None: - if self.assistants_config is not None: - custom_llm_provider = self.assistants_config["custom_llm_provider"] - kwargs.update(self.assistants_config["litellm_params"]) - else: - raise Exception( - "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" - ) - - return await litellm.adelete_assistant( - custom_llm_provider=custom_llm_provider, client=client, **kwargs - ) - - async def aget_assistants( - self, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, - client: Optional[AsyncOpenAI] = None, - **kwargs, - ) -> AsyncCursorPage[Assistant]: - if custom_llm_provider is None: - if self.assistants_config is not None: - custom_llm_provider = self.assistants_config["custom_llm_provider"] - kwargs.update(self.assistants_config["litellm_params"]) - else: - raise Exception( - "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" - ) - - return await litellm.aget_assistants( - custom_llm_provider=custom_llm_provider, client=client, **kwargs - ) - - async def acreate_thread( - self, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, - client: Optional[AsyncOpenAI] = None, - **kwargs, - ) -> Thread: - if custom_llm_provider is None: - if self.assistants_config is not None: - custom_llm_provider = self.assistants_config["custom_llm_provider"] - kwargs.update(self.assistants_config["litellm_params"]) - else: - raise Exception( - "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" - ) - return await litellm.acreate_thread( - custom_llm_provider=custom_llm_provider, client=client, **kwargs - ) - - async def aget_thread( - self, - thread_id: str, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, - client: Optional[AsyncOpenAI] = None, - **kwargs, - ) -> Thread: - if custom_llm_provider is None: - if self.assistants_config is not None: - custom_llm_provider = self.assistants_config["custom_llm_provider"] - kwargs.update(self.assistants_config["litellm_params"]) - else: - raise Exception( - "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" - ) - return await litellm.aget_thread( - custom_llm_provider=custom_llm_provider, - thread_id=thread_id, - client=client, + def factory_function(self, original_function: Callable): + async def new_function( + custom_llm_provider: Optional[Literal["openai", "azure"]] = None, + client: Optional["AsyncOpenAI"] = None, **kwargs, - ) + ): + return await self._pass_through_assistants_endpoint_factory( + original_function=original_function, + custom_llm_provider=custom_llm_provider, + client=client, + **kwargs, + ) - async def a_add_message( + return new_function + + async def _pass_through_assistants_endpoint_factory( self, - thread_id: str, - role: Literal["user", "assistant"], - content: str, - attachments: Optional[List[Attachment]] = None, - metadata: Optional[dict] = None, + original_function: Callable, custom_llm_provider: Optional[Literal["openai", "azure"]] = None, client: Optional[AsyncOpenAI] = None, **kwargs, - ) -> OpenAIMessage: + ): + """Internal helper function to pass through the assistants endpoint""" if custom_llm_provider is None: if self.assistants_config is not None: custom_llm_provider = self.assistants_config["custom_llm_provider"] @@ -2620,76 +2558,8 @@ class Router: raise Exception( "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" ) - - return await litellm.a_add_message( - custom_llm_provider=custom_llm_provider, - thread_id=thread_id, - role=role, - content=content, - attachments=attachments, - metadata=metadata, - client=client, - **kwargs, - ) - - async def aget_messages( - self, - thread_id: str, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, - client: Optional[AsyncOpenAI] = None, - **kwargs, - ) -> AsyncCursorPage[OpenAIMessage]: - if custom_llm_provider is None: - if self.assistants_config is not None: - custom_llm_provider = self.assistants_config["custom_llm_provider"] - kwargs.update(self.assistants_config["litellm_params"]) - else: - raise Exception( - "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" - ) - return await litellm.aget_messages( - custom_llm_provider=custom_llm_provider, - thread_id=thread_id, - client=client, - **kwargs, - ) - - async def arun_thread( - self, - thread_id: str, - assistant_id: str, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, - additional_instructions: Optional[str] = None, - instructions: Optional[str] = None, - metadata: Optional[dict] = None, - model: Optional[str] = None, - stream: Optional[bool] = None, - tools: Optional[Iterable[AssistantToolParam]] = None, - client: Optional[Any] = None, - **kwargs, - ) -> Run: - - if custom_llm_provider is None: - if self.assistants_config is not None: - custom_llm_provider = self.assistants_config["custom_llm_provider"] - kwargs.update(self.assistants_config["litellm_params"]) - else: - raise Exception( - "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" - ) - - return await litellm.arun_thread( - custom_llm_provider=custom_llm_provider, - thread_id=thread_id, - assistant_id=assistant_id, - additional_instructions=additional_instructions, - instructions=instructions, - metadata=metadata, - model=model, - stream=stream, - tools=tools, - client=client, - **kwargs, + return await original_function( # type: ignore + custom_llm_provider=custom_llm_provider, client=client, **kwargs ) #### [END] ASSISTANTS API #### @@ -4609,7 +4479,7 @@ class Router: """ returned_models: List[DeploymentTypedDict] = [] for model in self.model_list: - if model["model_name"] == model_name: + if model_name is not None and model["model_name"] == model_name: if model_alias is not None: alias_model = copy.deepcopy(model) alias_model["model_name"] = model_alias @@ -5007,82 +4877,73 @@ class Router: return _returned_deployments + def _get_model_from_alias(self, model: str) -> Optional[str]: + """ + Get the model from the alias. + + Returns: + - str, the litellm model name + - None, if model is not in model group alias + """ + if model not in self.model_group_alias: + return None + + _item = self.model_group_alias[model] + if isinstance(_item, str): + model = _item + else: + model = _item["model"] + + return model + + def _get_deployment_by_litellm_model(self, model: str) -> List: + """ + Get the deployment by litellm model. + """ + return [m for m in self.model_list if m["litellm_params"]["model"] == model] + def _common_checks_available_deployment( self, model: str, messages: Optional[List[Dict[str, str]]] = None, input: Optional[Union[str, List]] = None, specific_deployment: Optional[bool] = False, - ) -> Tuple[str, Union[list, dict]]: + ) -> Tuple[str, Union[List, Dict]]: """ Common checks for 'get_available_deployment' across sync + async call. If 'healthy_deployments' returned is None, this means the user chose a specific deployment Returns - - Dict, if specific model chosen + - str, the litellm model name - List, if multiple models chosen + - Dict, if specific model chosen """ + # check if aliases set on litellm model alias map if specific_deployment is True: - # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment - for deployment in self.model_list: - deployment_model = deployment.get("litellm_params").get("model") - if deployment_model == model: - # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2 - # return the first deployment where the `model` matches the specificed deployment name - return deployment_model, deployment - raise ValueError( - f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}" - ) + return model, self._get_deployment_by_litellm_model(model=model) elif model in self.get_model_ids(): - deployment = self.get_model_info(id=model) + deployment = self.get_deployment(model_id=model) if deployment is not None: - deployment_model = deployment.get("litellm_params", {}).get("model") - return deployment_model, deployment + deployment_model = deployment.litellm_params.model + return deployment_model, deployment.model_dump(exclude_none=True) raise ValueError( f"LiteLLM Router: Trying to call specific deployment, but Model ID :{model} does not exist in \ Model ID List: {self.get_model_ids}" ) - if model in self.model_group_alias: - _item = self.model_group_alias[model] - if isinstance(_item, str): - model = _item - else: - model = _item["model"] + _model_from_alias = self._get_model_from_alias(model=model) + if _model_from_alias is not None: + model = _model_from_alias if model not in self.model_names: # check if provider/ specific wildcard routing use pattern matching - custom_llm_provider: Optional[str] = None - try: - ( - _, - custom_llm_provider, - _, - _, - ) = litellm.get_llm_provider(model=model) - except Exception: - # get_llm_provider raises exception when provider is unknown - pass - - """ - self.pattern_router.route(model): - does exact pattern matching. Example openai/gpt-3.5-turbo gets routed to pattern openai/* - - self.pattern_router.route(f"{custom_llm_provider}/{model}"): - does pattern matching using litellm.get_llm_provider(), example claude-3-5-sonnet-20240620 gets routed to anthropic/* since 'claude-3-5-sonnet-20240620' is an Anthropic Model - """ - _pattern_router_response = self.pattern_router.route( - model - ) or self.pattern_router.route(f"{custom_llm_provider}/{model}") - if _pattern_router_response is not None: - provider_deployments = [] - for deployment in _pattern_router_response: - dep = copy.deepcopy(deployment) - dep["litellm_params"]["model"] = model - provider_deployments.append(dep) - return model, provider_deployments + pattern_deployments = self.pattern_router.get_deployments_by_pattern( + model=model, + ) + if pattern_deployments: + return model, pattern_deployments # check if default deployment is set if self.default_deployment is not None: @@ -5094,12 +4955,11 @@ class Router: ## get healthy deployments ### get all deployments - healthy_deployments = [m for m in self.model_list if m["model_name"] == model] + healthy_deployments = self._get_all_deployments(model_name=model) + if len(healthy_deployments) == 0: # check if the user sent in a deployment name instead - healthy_deployments = [ - m for m in self.model_list if m["litellm_params"]["model"] == model - ] + healthy_deployments = self._get_deployment_by_litellm_model(model=model) verbose_router_logger.debug( f"initial list of deployments: {healthy_deployments}" @@ -5151,7 +5011,6 @@ class Router: input=input, specific_deployment=specific_deployment, ) # type: ignore - if isinstance(healthy_deployments, dict): return healthy_deployments diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py index 814481f5e..e92049fac 100644 --- a/litellm/router_utils/pattern_match_deployments.py +++ b/litellm/router_utils/pattern_match_deployments.py @@ -2,9 +2,11 @@ Class to handle llm wildcard routing and regex pattern matching """ +import copy import re from typing import Dict, List, Optional +from litellm import get_llm_provider from litellm._logging import verbose_router_logger @@ -82,6 +84,55 @@ class PatternMatchRouter: return None # No matching pattern found + def get_pattern( + self, model: str, custom_llm_provider: Optional[str] = None + ) -> Optional[List[Dict]]: + """ + Check if a pattern exists for the given model and custom llm provider + + Args: + model: str + custom_llm_provider: Optional[str] + + Returns: + bool: True if pattern exists, False otherwise + """ + if custom_llm_provider is None: + try: + ( + _, + custom_llm_provider, + _, + _, + ) = get_llm_provider(model=model) + except Exception: + # get_llm_provider raises exception when provider is unknown + pass + return self.route(model) or self.route(f"{custom_llm_provider}/{model}") + + def get_deployments_by_pattern( + self, model: str, custom_llm_provider: Optional[str] = None + ) -> List[Dict]: + """ + Get the deployments by pattern + + Args: + model: str + custom_llm_provider: Optional[str] + + Returns: + List[Dict]: llm deployments matching the pattern + """ + pattern_match = self.get_pattern(model, custom_llm_provider) + if pattern_match: + provider_deployments = [] + for deployment in pattern_match: + dep = copy.deepcopy(deployment) + dep["litellm_params"]["model"] = model + provider_deployments.append(dep) + return provider_deployments + return [] + # Example usage: # router = PatternRouter() diff --git a/tests/code_coverage_tests/router_code_coverage.py b/tests/code_coverage_tests/router_code_coverage.py index 946c30220..5ed00203c 100644 --- a/tests/code_coverage_tests/router_code_coverage.py +++ b/tests/code_coverage_tests/router_code_coverage.py @@ -75,29 +75,28 @@ def get_functions_from_router(file_path): ignored_function_names = [ "__init__", - "_acreate_file", - "_acreate_batch", - "acreate_assistants", - "adelete_assistant", - "aget_assistants", - "acreate_thread", - "aget_thread", - "a_add_message", - "aget_messages", - "arun_thread", - "try_retrieve_batch", ] def main(): - router_file = "./litellm/router.py" # Update this path if it's located elsewhere - # router_file = "../../litellm/router.py" ## LOCAL TESTING + router_file = [ + "./litellm/router.py", + "./litellm/router_utils/batch_utils.py", + "./litellm/router_utils/pattern_match_deployments.py", + ] + # router_file = [ + # "../../litellm/router.py", + # "../../litellm/router_utils/pattern_match_deployments.py", + # "../../litellm/router_utils/batch_utils.py", + # ] ## LOCAL TESTING tests_dir = ( "./tests/" # Update this path if your tests directory is located elsewhere ) # tests_dir = "../../tests/" # LOCAL TESTING - router_functions = get_functions_from_router(router_file) + router_functions = [] + for file in router_file: + router_functions.extend(get_functions_from_router(file)) print("router_functions: ", router_functions) called_functions_in_tests = get_all_functions_called_in_tests(tests_dir) untested_functions = [ diff --git a/tests/code_coverage_tests/router_enforce_line_length.py b/tests/code_coverage_tests/router_enforce_line_length.py new file mode 100644 index 000000000..ed822e916 --- /dev/null +++ b/tests/code_coverage_tests/router_enforce_line_length.py @@ -0,0 +1,66 @@ +import ast +import os + +MAX_FUNCTION_LINES = 100 + + +def get_function_line_counts(file_path): + """ + Extracts all function names and their line counts from a given Python file. + """ + with open(file_path, "r") as file: + tree = ast.parse(file.read()) + + function_line_counts = [] + + for node in tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + # Top-level functions + line_count = node.end_lineno - node.lineno + 1 + function_line_counts.append((node.name, line_count)) + elif isinstance(node, ast.ClassDef): + # Functions inside classes + for class_node in node.body: + if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)): + line_count = class_node.end_lineno - class_node.lineno + 1 + function_line_counts.append((class_node.name, line_count)) + + return function_line_counts + + +ignored_functions = [ + "__init__", +] + + +def check_function_lengths(router_file): + """ + Checks if any function in the specified file exceeds the maximum allowed length. + """ + function_line_counts = get_function_line_counts(router_file) + long_functions = [ + (name, count) + for name, count in function_line_counts + if count > MAX_FUNCTION_LINES and name not in ignored_functions + ] + + if long_functions: + print("The following functions exceed the allowed line count:") + for name, count in long_functions: + print(f"- {name}: {count} lines") + raise Exception( + f"{len(long_functions)} functions in {router_file} exceed {MAX_FUNCTION_LINES} lines" + ) + else: + print("All functions in the router file are within the allowed line limit.") + + +def main(): + # Update this path to point to the correct location of router.py + router_file = "../../litellm/router.py" # LOCAL TESTING + + check_function_lengths(router_file) + + +if __name__ == "__main__": + main() diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index b97a4d191..a6316233a 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2569,6 +2569,15 @@ async def test_router_batch_endpoints(provider): ) print("Response from creating file=", file_obj) + ## TEST 2 - test underlying create_file function + file_obj = await router._acreate_file( + model="my-custom-name", + file=open(file_path, "rb"), + purpose="batch", + custom_llm_provider=provider, + ) + print("Response from creating file=", file_obj) + await asyncio.sleep(10) batch_input_file_id = file_obj.id assert ( @@ -2583,6 +2592,15 @@ async def test_router_batch_endpoints(provider): custom_llm_provider=provider, metadata={"key1": "value1", "key2": "value2"}, ) + ## TEST 2 - test underlying create_batch function + create_batch_response = await router._acreate_batch( + model="my-custom-name", + completion_window="24h", + endpoint="/v1/chat/completions", + input_file_id=batch_input_file_id, + custom_llm_provider=provider, + metadata={"key1": "value1", "key2": "value2"}, + ) print("response from router.create_batch=", create_batch_response) diff --git a/tests/router_unit_tests/test_router_batch_utils.py b/tests/router_unit_tests/test_router_batch_utils.py new file mode 100644 index 000000000..3d1bc9210 --- /dev/null +++ b/tests/router_unit_tests/test_router_batch_utils.py @@ -0,0 +1,84 @@ +import sys +import os +import traceback +from dotenv import load_dotenv +from fastapi import Request +from datetime import datetime + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from litellm import Router +import pytest +import litellm +from unittest.mock import patch, MagicMock, AsyncMock + +import json +from io import BytesIO +from typing import Dict, List +from litellm.router_utils.batch_utils import ( + replace_model_in_jsonl, + _get_router_metadata_variable_name, +) + + +# Fixtures +@pytest.fixture +def sample_jsonl_data() -> List[Dict]: + """Fixture providing sample JSONL data""" + return [ + { + "body": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + } + }, + {"body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}}, + ] + + +@pytest.fixture +def sample_jsonl_bytes(sample_jsonl_data) -> bytes: + """Fixture providing sample JSONL as bytes""" + jsonl_str = "\n".join(json.dumps(line) for line in sample_jsonl_data) + return jsonl_str.encode("utf-8") + + +@pytest.fixture +def sample_file_like(sample_jsonl_bytes): + """Fixture providing a file-like object""" + return BytesIO(sample_jsonl_bytes) + + +# Test cases +def test_bytes_input(sample_jsonl_bytes): + """Test with bytes input""" + new_model = "claude-3" + result = replace_model_in_jsonl(sample_jsonl_bytes, new_model) + + assert result is not None + + +def test_tuple_input(sample_jsonl_bytes): + """Test with tuple input""" + new_model = "claude-3" + test_tuple = ("test.jsonl", sample_jsonl_bytes, "application/json") + result = replace_model_in_jsonl(test_tuple, new_model) + + assert result is not None + + +def test_file_like_object(sample_file_like): + """Test with file-like object input""" + new_model = "claude-3" + result = replace_model_in_jsonl(sample_file_like, new_model) + + assert result is not None + + +def test_router_metadata_variable_name(): + """Test that the variable name is correct""" + assert _get_router_metadata_variable_name(function_name="completion") == "metadata" + assert ( + _get_router_metadata_variable_name(function_name="batch") == "litellm_metadata" + ) diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index f34eb428f..a97bf3197 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -41,6 +41,20 @@ def model_list(): "api_key": os.getenv("OPENAI_API_KEY"), }, }, + { + "model_name": "*", + "litellm_params": { + "model": "openai/*", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, + { + "model_name": "claude-*", + "litellm_params": { + "model": "anthropic/*", + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + }, ] @@ -834,3 +848,69 @@ def test_flush_cache(model_list): assert router.cache.get_cache("test") == "test" router.flush_cache() assert router.cache.get_cache("test") is None + + +def test_initialize_assistants_endpoint(model_list): + """Test if the 'initialize_assistants_endpoint' function is working correctly""" + router = Router(model_list=model_list) + router.initialize_assistants_endpoint() + assert router.acreate_assistants is not None + assert router.adelete_assistant is not None + assert router.aget_assistants is not None + assert router.acreate_thread is not None + assert router.aget_thread is not None + assert router.arun_thread is not None + assert router.aget_messages is not None + assert router.a_add_message is not None + + +def test_pass_through_assistants_endpoint_factory(model_list): + """Test if the 'pass_through_assistants_endpoint_factory' function is working correctly""" + router = Router(model_list=model_list) + router._pass_through_assistants_endpoint_factory( + original_function=litellm.acreate_assistants, + custom_llm_provider="openai", + client=None, + **{}, + ) + + +def test_factory_function(model_list): + """Test if the 'factory_function' function is working correctly""" + router = Router(model_list=model_list) + router.factory_function(litellm.acreate_assistants) + + +def test_get_model_from_alias(model_list): + """Test if the 'get_model_from_alias' function is working correctly""" + router = Router( + model_list=model_list, + model_group_alias={"gpt-4o": "gpt-3.5-turbo"}, + ) + model = router._get_model_from_alias(model="gpt-4o") + assert model == "gpt-3.5-turbo" + + +def test_get_deployment_by_litellm_model(model_list): + """Test if the 'get_deployment_by_litellm_model' function is working correctly""" + router = Router(model_list=model_list) + deployment = router._get_deployment_by_litellm_model(model="gpt-3.5-turbo") + assert deployment is not None + + +def test_get_pattern(model_list): + router = Router(model_list=model_list) + pattern = router.pattern_router.get_pattern(model="claude-3") + assert pattern is not None + + +def test_deployments_by_pattern(model_list): + router = Router(model_list=model_list) + deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3") + assert deployments is not None + + +def test_replace_model_in_jsonl(model_list): + router = Router(model_list=model_list) + deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3") + assert deployments is not None