refactor(test_router_caching.py): move tpm/rpm routing tests to separate file

Krrish Dholakia 2024-01-02 11:10:11 +05:30
parent 18ef244230
commit dff4c172d0
3 changed files with 18 additions and 45 deletions

@@ -1564,6 +1564,23 @@ class Router:
             ):
                 model["litellm_params"]["tpm"] = model.get("tpm")
+
+            #### VALIDATE MODEL ########
+            # check if model provider in supported providers
+            (
+                model,
+                custom_llm_provider,
+                dynamic_api_key,
+                api_base,
+            ) = litellm.get_llm_provider(
+                model=model["litellm_params"]["model"],
+                custom_llm_provider=model["litellm_params"].get(
+                    "custom_llm_provider", None
+                ),
+            )
+            if custom_llm_provider not in litellm.provider_list:
+                raise Exception(f"Unsupported provider - {custom_llm_provider}")
+
             self.set_client(model=model)
         self.print_verbose(f"\nInitialized Model List {self.model_list}")
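
In practice, this moves provider errors from request time to Router construction time. A minimal sketch of the new failure mode, assuming standard litellm semantics ("Custom-LLM" below is an arbitrary placeholder, not a real provider):

    # Sketch: with the validation above, an unsupported provider now fails
    # when Router() is constructed instead of on the first completion call.
    from litellm import Router

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0301",
                # arbitrary unsupported provider name
                "custom_llm_provider": "Custom-LLM",
            },
        }
    ]

    try:
        router = Router(model_list=model_list)
    except Exception as e:
        print(e)  # Unsupported provider - Custom-LLM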

@@ -182,7 +182,7 @@ class LowestTPMLoggingHandler(CustomLogger):
                     break
                 elif (
                     item_tpm + input_tokens > _deployment_tpm
-                    or rpm_dict[item] + 1 >= _deployment_rpm
+                    or rpm_dict[item] + 1 > _deployment_rpm
                 ): # if user passed in tpm / rpm in the model_list
                     continue
                 elif item_tpm < lowest_tpm:
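
The `>=` to `>` change fixes an off-by-one: `rpm_dict[item]` counts requests already routed this minute and the `+ 1` accounts for the incoming request, so the old check skipped a deployment one request early, capping it at rpm - 1 instead of rpm. A standalone sketch of the boundary (function and variable names here are illustrative, not litellm's):

    # Illustrative boundary check, not litellm's actual API.
    def over_rpm_limit(current_rpm: int, deployment_rpm: int) -> bool:
        # current_rpm + 1 = requests this minute if we route one more.
        # The old ">=" check skipped the deployment one request early.
        return current_rpm + 1 > deployment_rpm

    deployment_rpm = 5
    served = sum(
        1 for sent in range(10) if not over_rpm_limit(sent, deployment_rpm)
    )
    assert served == 5  # with the old ">=" check this would be 4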

@@ -209,47 +209,3 @@ async def test_acompletion_caching_on_router_caching_groups():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
-
-
-def test_usage_based_routing_completion():
-    litellm.set_verbose = True
-    model_list = [
-        {
-            "model_name": "gpt-3.5-turbo",
-            "litellm_params": {
-                "model": "gpt-3.5-turbo-0301",
-                "api_key": os.getenv("OPENAI_API_KEY"),
-                "custom_llm_provider": "Custom-LLM",
-            },
-            "tpm": 10000,
-            "rpm": 5,
-        },
-        {
-            "model_name": "gpt-3.5-turbo",
-            "litellm_params": {
-                "model": "gpt-3.5-turbo-0301",
-                "api_key": os.getenv("OPENAI_API_KEY"),
-            },
-            "tpm": 10000,
-            "rpm": 5,
-        },
-    ]
-    router = Router(
-        model_list=model_list, routing_strategy="usage-based-routing", set_verbose=False
-    )
-    max_requests = 5
-    while max_requests > 0:
-        try:
-            router.completion(
-                model="gpt-3.5-turbo",
-                messages=[{"content": "write a one sentence poem.", "role": "user"}],
-            )
-        except ValueError as e:
-            traceback.print_exc()
-            pytest.fail(f"Error occurred: {e}")
-        finally:
-            max_requests -= 1
-    router.reset()
-
-
-test_usage_based_routing_completion()
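
For reference, the relocated test presumably keeps the same shape: deployments sharing one model_name, each carrying tpm/rpm limits, driven by routing_strategy="usage-based-routing". A condensed sketch (the destination file isn't named in this commit, and the "Custom-LLM" deployment is dropped because the new Router validation above would reject it at construction time):

    # Condensed, illustrative version of the moved usage-based routing test.
    import os
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0301",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
                "tpm": 10000,
                "rpm": 5,
            }
        ],
        routing_strategy="usage-based-routing",
    )

    for _ in range(5):  # stay within the deployment's rpm=5 budget
        router.completion(
            model="gpt-3.5-turbo",
            messages=[{"content": "write a one sentence poem.", "role": "user"}],
        )
    router.reset()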