diff --git a/litellm/router.py b/litellm/router.py
index caab343143..f932e5e186 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -11,7 +11,7 @@ class Cache:
         if cache_config["type"] == "redis":
             pass
         elif cache_config["type"] == "local":
-            self.usage_dict = {}
+            self.usage_dict: Dict = {}
 
     def get(self, key: str):
         return self.usage_dict.get(key, 0)
@@ -86,7 +86,7 @@ class Router:
                   is_async: Optional[bool] = False,
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(model=model)
+        deployment = self.get_available_deployment(model=model, input=input)
         data = deployment["litellm_params"]
         data["input"] = input
@@ -109,7 +109,8 @@ class Router:
     def get_available_deployment(self,
                                  model: str,
-                                 messages: List[Dict[str, str]]):
+                                 messages: Optional[List[Dict[str, str]]]=None,
+                                 input: Optional[Union[str, List]]=None):
         """
         Returns a deployment with the lowest TPM/RPM usage.
         """
@@ -131,8 +132,14 @@ class Router:
             current_tpm, current_rpm = self._get_deployment_usage(deployment_name=deployment["litellm_params"]["model"])
 
             # get encoding
-            token_count = litellm.token_counter(model=deployment["model_name"], messages=messages)
-
+            if messages:
+                token_count = litellm.token_counter(model=deployment["model_name"], messages=messages)
+            elif input:
+                if isinstance(input, List):
+                    input_text = "".join(text for text in input)
+                else:
+                    input_text = input
+                token_count = litellm.token_counter(model=deployment["model_name"], text=input_text)
 
             # if at model limit, return lowest used
             if current_tpm + token_count > tpm or current_rpm + 1 >= rpm:
diff --git a/litellm/tests/test_proxy.py b/litellm/tests/test_proxy.py
index b8d19b065f..bd20cabb82 100644
--- a/litellm/tests/test_proxy.py
+++ b/litellm/tests/test_proxy.py
@@ -28,7 +28,7 @@ def test_azure_call():
 ## test debug
 def test_debug():
     try:
-        initialize(model=None, alias=None, api_base=None, debug=True, temperature=None, max_tokens=None, max_budget=None, telemetry=None, drop_params=None, add_function_to_prompt=None, headers=None, save=None)
+        initialize(model=None, alias=None, api_base=None, debug=True, temperature=None, max_tokens=None, max_budget=None, telemetry=None, drop_params=None, add_function_to_prompt=None, headers=None, save=None, api_version=None)
         assert litellm.set_verbose == True
     except Exception as e:
         pytest.fail(f"An error occurred: {e}")
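
Review note: the new `input` branch only runs for embedding calls, where `input` is a plain string or a list of strings, while completion calls keep passing `messages`. Below is a minimal sketch of that branch in isolation, assuming `litellm.token_counter` (which accepts either `messages` or `text`); the helper name `count_request_tokens` is hypothetical and not part of this PR.

```python
from typing import Dict, List, Optional, Union

import litellm


def count_request_tokens(
    model_name: str,
    messages: Optional[List[Dict[str, str]]] = None,
    input: Optional[Union[str, List[str]]] = None,
) -> int:
    """Hypothetical helper mirroring the branch added to get_available_deployment."""
    if messages:
        # completion-style call: count tokens across the chat messages
        return litellm.token_counter(model=model_name, messages=messages)
    elif input:
        # embedding-style call: join a list of strings before counting
        input_text = "".join(input) if isinstance(input, list) else input
        return litellm.token_counter(model=model_name, text=input_text)
    # neither provided: in the diff as written, token_count would be unbound here
    return 0
```

One edge worth flagging: if both `messages` and `input` are falsy, the diff leaves `token_count` unassigned, so the later `current_tpm + token_count > tpm` check would raise `UnboundLocalError`; the sketch above falls back to 0 instead.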