test(test_router_init.py): fix test router init

This commit is contained in:
Krrish Dholakia 2023-12-30 16:51:39 +05:30
parent 3cb7acceaa
commit c41b1418d4

View file

@ -41,14 +41,17 @@ def test_init_clients():
] ]
router = Router(model_list=model_list) router = Router(model_list=model_list)
for elem in router.model_list: for elem in router.model_list:
assert elem["client"] is not None model_id = elem["model_info"]["id"]
assert elem["async_client"] is not None assert router.cache.get_cache(f"{model_id}_client") is not None
assert elem["stream_client"] is not None assert router.cache.get_cache(f"{model_id}_async_client") is not None
assert elem["stream_async_client"] is not None assert router.cache.get_cache(f"{model_id}_stream_client") is not None
assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None
# check if timeout for stream/non stream clients is set correctly # check if timeout for stream/non stream clients is set correctly
async_client = elem["async_client"] async_client = router.cache.get_cache(f"{model_id}_async_client")
stream_async_client = elem["stream_async_client"] stream_async_client = router.cache.get_cache(
f"{model_id}_stream_async_client"
)
assert async_client.timeout == 0.01 assert async_client.timeout == 0.01
assert stream_async_client.timeout == 0.000_001 assert stream_async_client.timeout == 0.000_001
@ -79,10 +82,11 @@ def test_init_clients_basic():
] ]
router = Router(model_list=model_list) router = Router(model_list=model_list)
for elem in router.model_list: for elem in router.model_list:
assert elem["client"] is not None model_id = elem["model_info"]["id"]
assert elem["async_client"] is not None assert router.cache.get_cache(f"{model_id}_client") is not None
assert elem["stream_client"] is not None assert router.cache.get_cache(f"{model_id}_async_client") is not None
assert elem["stream_async_client"] is not None assert router.cache.get_cache(f"{model_id}_stream_client") is not None
assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None
print("PASSED !") print("PASSED !")
# see if we can init clients without timeout or max retries set # see if we can init clients without timeout or max retries set