Mirror of https://github.com/BerriAI/litellm.git
fix(test_proxy.py): fix tests

commit fa488e29e0 (parent ab6f58f252)
2 changed files with 13 additions and 6 deletions
@@ -11,7 +11,7 @@ class Cache:
         if cache_config["type"] == "redis":
             pass
         elif cache_config["type"] == "local":
-            self.usage_dict = {}
+            self.usage_dict: Dict = {}
     def get(self, key: str):
         return self.usage_dict.get(key, 0)
 
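The annotated assignment above assumes `Dict` is imported from `typing` in that module. As a rough sketch of the local-cache idea (class and method names here are illustrative, not litellm's API):

    from typing import Dict

    class LocalUsageCache:
        # Sketch only: a per-process dict standing in for a shared cache backend.
        def __init__(self) -> None:
            self.usage_dict: Dict = {}  # deployment key -> running usage count

        def get(self, key: str) -> int:
            # default to 0 so callers can add to the value without a None check
            return self.usage_dict.get(key, 0)

        def increment(self, key: str, value: int = 1) -> None:
            self.usage_dict[key] = self.get(key) + value
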
@@ -86,7 +86,7 @@ class Router:
                   is_async: Optional[bool] = False,
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(model=model)
+        deployment = self.get_available_deployment(model=model, input=input)
 
         data = deployment["litellm_params"]
         data["input"] = input
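For context, each deployment the router tracks pairs a public `model_name` with the `litellm_params` that are forwarded to the provider; the hunk above copies those params and attaches the embedding `input` before dispatch. A hedged example of what such an entry might look like (only the keys used in this diff come from the source, the values are placeholders):

    # Hypothetical deployment entry; "model_name" and "litellm_params" are the
    # keys referenced in the diff, everything else is a placeholder.
    deployment = {
        "model_name": "text-embedding-ada-002",
        "litellm_params": {
            "model": "azure/my-embedding-deployment",  # provider-side model id
            "api_key": "sk-...",                        # placeholder credential
        },
    }

    # Mirrors the dispatch step in the hunk: provider params plus the embedding input.
    data = dict(deployment["litellm_params"])
    data["input"] = ["good morning from litellm"]
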
@@ -109,7 +109,8 @@ class Router:
 
     def get_available_deployment(self,
                                  model: str,
-                                 messages: List[Dict[str, str]]):
+                                 messages: Optional[List[Dict[str, str]]]=None,
+                                 input: Optional[Union[str, List]]=None):
         """
         Returns a deployment with the lowest TPM/RPM usage.
         """
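With both parameters optional, deployment selection can be driven either by chat `messages` or by raw embedding `input`. A usage sketch under the assumption that a `Router` has already been built from a model list:

    # Assumed setup; the Router constructor arguments are not part of this diff.
    router = Router(model_list=model_list)

    # Chat-style selection: token usage is estimated from the messages.
    chat_deployment = router.get_available_deployment(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )

    # Embedding-style selection: token usage is estimated from the raw input instead.
    embed_deployment = router.get_available_deployment(
        model="text-embedding-ada-002",
        input=["good morning from litellm"],
    )
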
@@ -131,8 +132,14 @@ class Router:
             current_tpm, current_rpm = self._get_deployment_usage(deployment_name=deployment["litellm_params"]["model"])
 
             # get encoding
-            token_count = litellm.token_counter(model=deployment["model_name"], messages=messages)
-
+            if messages:
+                token_count = litellm.token_counter(model=deployment["model_name"], messages=messages)
+            elif input:
+                if isinstance(input, List):
+                    input_text = "".join(text for text in input)
+                else:
+                    input_text = input
+                token_count = litellm.token_counter(model=deployment["model_name"], text=input_text)
 
             # if at model limit, return lowest used
             if current_tpm + token_count > tpm or current_rpm + 1 >= rpm:
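The new branch estimates token usage from whichever payload is present: chat messages when given, otherwise the embedding input flattened to a single string. Read as a standalone helper, it might look roughly like this (the function name is ours, not litellm's):

    from typing import Dict, List, Optional, Union

    import litellm

    def estimate_request_tokens(model_name: str,
                                messages: Optional[List[Dict[str, str]]] = None,
                                input: Optional[Union[str, List]] = None) -> int:
        # Hypothetical helper mirroring the branch added in this hunk.
        if messages:
            # chat request: token_counter walks the message dicts itself
            return litellm.token_counter(model=model_name, messages=messages)
        if input:
            # embedding request: join list inputs into one string before counting
            input_text = "".join(input) if isinstance(input, list) else input
            return litellm.token_counter(model=model_name, text=input_text)
        return 0
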
@@ -28,7 +28,7 @@ def test_azure_call():
 ## test debug
 def test_debug():
     try:
-        initialize(model=None, alias=None, api_base=None, debug=True, temperature=None, max_tokens=None, max_budget=None, telemetry=None, drop_params=None, add_function_to_prompt=None, headers=None, save=None)
+        initialize(model=None, alias=None, api_base=None, debug=True, temperature=None, max_tokens=None, max_budget=None, telemetry=None, drop_params=None, add_function_to_prompt=None, headers=None, save=None, api_version=None)
         assert litellm.set_verbose == True
     except Exception as e:
         pytest.fail(f"An error occurred: {e}")