fix(router.py): periodically re-initialize Azure/OpenAI clients to resolve the max-connections issue

This commit is contained in:
Krrish Dholakia 2023-12-30 15:48:34 +05:30
parent d089157925
commit 69935db239
4 changed files with 451 additions and 242 deletions

View file

@ -778,7 +778,8 @@ def test_reading_keys_os_environ():
os.environ["AZURE_MAX_RETRIES"]
), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
print("passed testing of reading keys from os.environ")
async_client: openai.AsyncAzureOpenAI = model["async_client"] # type: ignore
model_id = model["model_info"]["id"]
async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client") # type: ignore
assert async_client.api_key == os.environ["AZURE_API_KEY"]
assert async_client.base_url == os.environ["AZURE_API_BASE"]
assert async_client.max_retries == (
@ -791,7 +792,7 @@ def test_reading_keys_os_environ():
print("\n Testing async streaming client")
stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"] # type: ignore
stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client") # type: ignore
assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
assert stream_async_client.max_retries == (
@ -803,7 +804,7 @@ def test_reading_keys_os_environ():
print("async stream client set correctly!")
print("\n Testing sync client")
client: openai.AzureOpenAI = model["client"] # type: ignore
client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client") # type: ignore
assert client.api_key == os.environ["AZURE_API_KEY"]
assert client.base_url == os.environ["AZURE_API_BASE"]
assert client.max_retries == (
@ -815,7 +816,7 @@ def test_reading_keys_os_environ():
print("sync client set correctly!")
print("\n Testing sync stream client")
stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore
stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client") # type: ignore
assert stream_client.api_key == os.environ["AZURE_API_KEY"]
assert stream_client.base_url == os.environ["AZURE_API_BASE"]
assert stream_client.max_retries == (
@ -877,7 +878,8 @@ def test_reading_openai_keys_os_environ():
os.environ["AZURE_MAX_RETRIES"]
), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
print("passed testing of reading keys from os.environ")
async_client: openai.AsyncOpenAI = model["async_client"] # type: ignore
model_id = model["model_info"]["id"]
async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_async_client") # type: ignore
assert async_client.api_key == os.environ["OPENAI_API_KEY"]
assert async_client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]
@ -889,7 +891,7 @@ def test_reading_openai_keys_os_environ():
print("\n Testing async streaming client")
stream_async_client: openai.AsyncOpenAI = model["stream_async_client"] # type: ignore
stream_async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_stream_async_client") # type: ignore
assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"]
assert stream_async_client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]
@ -900,7 +902,7 @@ def test_reading_openai_keys_os_environ():
print("async stream client set correctly!")
print("\n Testing sync client")
client: openai.AzureOpenAI = model["client"] # type: ignore
client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_client") # type: ignore
assert client.api_key == os.environ["OPENAI_API_KEY"]
assert client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]
@ -911,7 +913,7 @@ def test_reading_openai_keys_os_environ():
print("sync client set correctly!")
print("\n Testing sync stream client")
stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore
stream_client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_stream_client") # type: ignore
assert stream_client.api_key == os.environ["OPENAI_API_KEY"]
assert stream_client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]