fix(test_proxy.py): fix tests

Krrish Dholakia 2023-10-17 22:34:12 -07:00
parent ab6f58f252
commit fa488e29e0
2 changed files with 13 additions and 6 deletions

@@ -11,7 +11,7 @@ class Cache:
         if cache_config["type"] == "redis":
             pass
         elif cache_config["type"] == "local":
-            self.usage_dict = {}
+            self.usage_dict: Dict = {}
 
     def get(self, key: str):
         return self.usage_dict.get(key, 0)
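
A minimal sketch of the local-cache branch above, assuming the Cache constructor takes the cache_config dict shown (the Dict annotation only helps type checkers; runtime behavior is unchanged):

    # hypothetical usage; the Cache(cache_config=...) constructor shape is assumed
    cache = Cache(cache_config={"type": "local"})
    cache.usage_dict["azure/gpt-3.5-turbo"] = 90   # record usage for a deployment
    print(cache.get("azure/gpt-3.5-turbo"))        # -> 90
    print(cache.get("missing-key"))                # -> 0, the default for absent keys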
@@ -86,7 +86,7 @@ class Router:
                   is_async: Optional[bool] = False,
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(model=model)
+        deployment = self.get_available_deployment(model=model, input=input)
 
         data = deployment["litellm_params"]
         data["input"] = input
@@ -109,7 +109,8 @@ class Router:
     def get_available_deployment(self,
                                  model: str,
-                                 messages: List[Dict[str, str]]):
+                                 messages: Optional[List[Dict[str, str]]]=None,
+                                 input: Optional[Union[str, List]]=None):
         """
         Returns a deployment with the lowest TPM/RPM usage.
         """
@@ -131,8 +132,14 @@
             current_tpm, current_rpm = self._get_deployment_usage(deployment_name=deployment["litellm_params"]["model"])
 
             # get encoding
-            token_count = litellm.token_counter(model=deployment["model_name"], messages=messages)
+            if messages:
+                token_count = litellm.token_counter(model=deployment["model_name"], messages=messages)
+            elif input:
+                if isinstance(input, List):
+                    input_text = "".join(text for text in input)
+                else:
+                    input_text = input
+                token_count = litellm.token_counter(model=deployment["model_name"], text=input_text)
 
             # if at model limit, return lowest used
             if current_tpm + token_count > tpm or current_rpm + 1 >= rpm:
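
The two counting paths can be exercised directly with litellm.token_counter, which accepts either messages= or text= (as the branch above relies on). A small sketch:

    import litellm

    # chat path: token count derived from the messages list
    messages = [{"role": "user", "content": "hey, how's it going?"}]
    chat_tokens = litellm.token_counter(model="gpt-3.5-turbo", messages=messages)

    # embedding path: list input is concatenated into one string first,
    # mirroring the join in get_available_deployment above
    input = ["hello", "world"]
    input_text = "".join(input) if isinstance(input, list) else input
    embed_tokens = litellm.token_counter(model="gpt-3.5-turbo", text=input_text)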

@@ -28,7 +28,7 @@ def test_azure_call():
 ## test debug
 def test_debug():
     try:
-        initialize(model=None, alias=None, api_base=None, debug=True, temperature=None, max_tokens=None, max_budget=None, telemetry=None, drop_params=None, add_function_to_prompt=None, headers=None, save=None)
+        initialize(model=None, alias=None, api_base=None, debug=True, temperature=None, max_tokens=None, max_budget=None, telemetry=None, drop_params=None, add_function_to_prompt=None, headers=None, save=None, api_version=None)
         assert litellm.set_verbose == True
     except Exception as e:
         pytest.fail(f"An error occurred: {e}")