fix(router.py): periodically re-initialize azure/openai clients to solve max conn issue

Krrish Dholakia 2023-12-30 15:48:34 +05:30
parent d089157925
commit 69935db239
4 changed files with 451 additions and 242 deletions


@@ -133,30 +133,31 @@ class DualCache(BaseCache):
        # If redis_cache is not provided, use the default RedisCache
        self.redis_cache = redis_cache

-    def set_cache(self, key, value, **kwargs):
+    def set_cache(self, key, value, local_only: bool = False, **kwargs):
        # Update both Redis and in-memory cache
        try:
            print_verbose(f"set cache: key: {key}; value: {value}")
            if self.in_memory_cache is not None:
                self.in_memory_cache.set_cache(key, value, **kwargs)
-            if self.redis_cache is not None:
+            if self.redis_cache is not None and local_only == False:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

-    def get_cache(self, key, **kwargs):
+    def get_cache(self, key, local_only: bool = False, **kwargs):
        # Try to fetch from in-memory cache first
        try:
-            print_verbose(f"get cache: cache key: {key}")
+            print_verbose(f"get cache: cache key: {key}; local_only: {local_only}")
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
+                print_verbose(f"in_memory_result: {in_memory_result}")
                if in_memory_result is not None:
                    result = in_memory_result
-            if self.redis_cache is not None:
+            if result is None and self.redis_cache is not None and local_only == False:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = self.redis_cache.get_cache(key, **kwargs)

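For context, a minimal sketch of what the new local_only flag does, assuming DualCache and InMemoryCache are importable from litellm.caching (the import path is an assumption; the behavior mirrors the hunk above):

    # Hypothetical sketch - not part of this commit.
    from litellm.caching import DualCache, InMemoryCache  # assumed import path

    dual_cache = DualCache(redis_cache=None, in_memory_cache=InMemoryCache())

    # local_only=True keeps the write purely in-process, so non-serializable
    # values (like openai client objects) never hit Redis.
    dual_cache.set_cache(key="1234_async_client", value=object(), ttl=5, local_only=True)

    # local_only=True on reads skips the Redis lookup as well.
    assert dual_cache.get_cache(key="1234_async_client", local_only=True) is not None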

@@ -84,6 +84,7 @@ class Router:
        caching_groups: Optional[
            List[tuple]
        ] = None,  # if you want to cache across model groups
+        client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
        ## RELIABILITY ##
        num_retries: int = 0,
        timeout: Optional[float] = None,
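As a usage sketch (mirroring the new test added at the bottom of this commit), client_ttl is simply passed through the Router constructor; once that many seconds elapse, the cached clients expire and are rebuilt on the next request. Assumes OPENAI_API_KEY is set:

    # Hypothetical sketch based on the test added in this commit.
    import os
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0613",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
                "model_info": {"id": "1234"},
            }
        ],
        client_ttl=120,  # rebuild cached azure/openai clients every 2 minutes
    )

    # clients are cached per deployment id, e.g. "1234_async_client"
    print(router.cache.get_cache(key="1234_async_client", local_only=True))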
@@ -106,6 +107,43 @@ class Router:
            []
        )  # names of models under litellm_params. ex. azure/chatgpt-v-2
        self.deployment_latency_map = {}
+        ### CACHING ###
+        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
+        redis_cache = None
+        cache_config = {}
+        self.client_ttl = client_ttl
+        if redis_url is not None or (
+            redis_host is not None
+            and redis_port is not None
+            and redis_password is not None
+        ):
+            cache_type = "redis"
+            if redis_url is not None:
+                cache_config["url"] = redis_url
+            if redis_host is not None:
+                cache_config["host"] = redis_host
+            if redis_port is not None:
+                cache_config["port"] = str(redis_port)  # type: ignore
+            if redis_password is not None:
+                cache_config["password"] = redis_password
+            # Add additional key-value pairs from cache_kwargs
+            cache_config.update(cache_kwargs)
+            redis_cache = RedisCache(**cache_config)
+        if cache_responses:
+            if litellm.cache is None:
+                # the cache can be initialized on the proxy server. We should not overwrite it
+                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
+            self.cache_responses = cache_responses
+        self.cache = DualCache(
+            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
+        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
        if model_list:
            model_list = copy.deepcopy(model_list)
            self.set_model_list(model_list)
@@ -155,40 +193,6 @@ class Router:
                {"caching_groups": caching_groups}
            )

-        ### CACHING ###
-        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
-        redis_cache = None
-        cache_config = {}
-        if redis_url is not None or (
-            redis_host is not None
-            and redis_port is not None
-            and redis_password is not None
-        ):
-            cache_type = "redis"
-            if redis_url is not None:
-                cache_config["url"] = redis_url
-            if redis_host is not None:
-                cache_config["host"] = redis_host
-            if redis_port is not None:
-                cache_config["port"] = str(redis_port)  # type: ignore
-            if redis_password is not None:
-                cache_config["password"] = redis_password
-            # Add additional key-value pairs from cache_kwargs
-            cache_config.update(cache_kwargs)
-            redis_cache = RedisCache(**cache_config)
-        if cache_responses:
-            if litellm.cache is None:
-                # the cache can be initialized on the proxy server. We should not overwrite it
-                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
-            self.cache_responses = cache_responses
-        self.cache = DualCache(
-            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
-        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
        ### ROUTING SETUP ###
        if routing_strategy == "least-busy":
            self.leastbusy_logger = LeastBusyLoggingHandler(
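For reference, a short sketch of the two ways this block picks up Redis settings, either a single redis_url or host/port/password. Values are placeholders and the relevant env vars are assumed to be set:

    # Hypothetical sketch; values are placeholders, env vars assumed set.
    import os
    from litellm import Router

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0613",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]

    # Option 1: host/port/password -> cache_config = {"host": ..., "port": ..., "password": ...}
    router = Router(
        model_list=model_list,
        redis_host=os.getenv("REDIS_HOST"),
        redis_port=os.getenv("REDIS_PORT"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        cache_responses=True,
    )

    # Option 2: a single connection URL -> cache_config = {"url": ...}
    router = Router(model_list=model_list, redis_url=os.getenv("REDIS_URL"), cache_responses=True)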
@@ -1248,23 +1252,17 @@ class Router:
            chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
            return chosen_item

-    def set_model_list(self, model_list: list):
-        self.model_list = copy.deepcopy(model_list)
-        # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
-        import os
-        for model in self.model_list:
+    def set_client(self, model: dict):
+        """
+        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        """
+        client_ttl = self.client_ttl
        litellm_params = model.get("litellm_params", {})
        model_name = litellm_params.get("model")
-            #### MODEL ID INIT ########
-            model_info = model.get("model_info", {})
-            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
-            model["model_info"] = model_info
+        model_id = model["model_info"]["id"]
        #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
        custom_llm_provider = litellm_params.get("custom_llm_provider")
-            custom_llm_provider = (
-                custom_llm_provider or model_name.split("/", 1)[0] or ""
-            )
+        custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
        default_api_base = None
        default_api_key = None
        if custom_llm_provider in litellm.openai_compatible_providers:
@@ -1324,9 +1322,7 @@ class Router:
            litellm_params["stream_timeout"] = stream_timeout

        max_retries = litellm_params.pop("max_retries", 2)
-        if isinstance(max_retries, str) and max_retries.startswith(
-            "os.environ/"
-        ):
+        if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
            max_retries_env_name = max_retries.replace("os.environ/", "")
            max_retries = litellm.get_secret(max_retries_env_name)
        litellm_params["max_retries"] = max_retries
@@ -1343,41 +1339,71 @@ class Router:
                        api_base += "/"
                    azure_model = model_name.replace("azure/", "")
                    api_base += f"{azure_model}"
-                    model["async_client"] = openai.AsyncAzureOpenAI(
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(
                        api_key=api_key,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=timeout,
                        max_retries=max_retries,
                    )
-                    model["client"] = openai.AzureOpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        api_version=api_version,
-                        timeout=timeout,
-                        max_retries=max_retries,
-                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
                    # streaming clients can have diff timeouts
-                    model["stream_async_client"] = openai.AsyncAzureOpenAI(
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        api_key=api_key,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,
                        max_retries=max_retries,
                    )
-                    model["stream_client"] = openai.AzureOpenAI(
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    cache_key = f"{model_id}_stream_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
                        api_key=api_key,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,
                        max_retries=max_retries,
                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
                else:
                    self.print_verbose(
                        f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
                    )
-                    model["async_client"] = openai.AsyncAzureOpenAI(
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_endpoint=api_base,
                        api_version=api_version,
@@ -1387,7 +1413,15 @@ class Router:
                            transport=AsyncCustomHTTPTransport(),
                        ),  # type: ignore
                    )
-                    model["client"] = openai.AzureOpenAI(
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    cache_key = f"{model_id}_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_endpoint=api_base,
                        api_version=api_version,
@@ -1397,59 +1431,118 @@ class Router:
                            transport=CustomHTTPTransport(),
                        ),  # type: ignore
                    )
-                    # streaming clients should have diff timeouts
-                    model["stream_async_client"] = openai.AsyncAzureOpenAI(
-                        api_key=api_key,
-                        azure_endpoint=api_base,
-                        api_version=api_version,
-                        timeout=stream_timeout,
-                        max_retries=max_retries,
-                    )
-                    model["stream_client"] = openai.AzureOpenAI(
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    # streaming clients should have diff timeouts
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_endpoint=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,
                        max_retries=max_retries,
                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    cache_key = f"{model_id}_stream_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
            else:
                self.print_verbose(
                    f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
                )
-                model["async_client"] = openai.AsyncOpenAI(
+                cache_key = f"{model_id}_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                )
-                model["client"] = openai.OpenAI(
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+                cache_key = f"{model_id}_client"
+                _client = openai.OpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
                # streaming clients should have diff timeouts
-                model["stream_async_client"] = openai.AsyncOpenAI(
+                cache_key = f"{model_id}_stream_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
                # streaming clients should have diff timeouts
-                model["stream_client"] = openai.OpenAI(
+                cache_key = f"{model_id}_stream_client"
+                _client = openai.OpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr

-            ############ End of initializing Clients for OpenAI/Azure ###################
+    def set_model_list(self, model_list: list):
+        self.model_list = copy.deepcopy(model_list)
+        # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
+        import os
+
+        for model in self.model_list:
+            #### MODEL ID INIT ########
+            model_info = model.get("model_info", {})
+            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
+            model["model_info"] = model_info
+            #### DEPLOYMENT NAMES INIT ########
            self.deployment_names.append(model["litellm_params"]["model"])
            ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
            # for get_available_deployment, we use the litellm_param["rpm"]
            # in this snippet we also set rpm to be a litellm_param
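The cache keys that set_client() writes follow a simple "<model_id>_<client kind>" convention. A small sketch of pulling them back out by hand, assuming `router` is an initialized Router from one of the earlier sketches:

    # Hypothetical sketch; `router` is assumed to be an initialized litellm Router.
    model_id = router.model_list[0]["model_info"]["id"]

    for kind in ("async_client", "client", "stream_async_client", "stream_client"):
        cache_key = f"{model_id}_{kind}"
        client = router.cache.get_cache(key=cache_key, local_only=True)
        print(cache_key, "->", type(client).__name__ if client is not None else None)

Once client_ttl seconds pass, these lookups return None until the next _get_client() call re-initializes the clients.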
@@ -1464,6 +1557,8 @@ class Router:
            ):
                model["litellm_params"]["tpm"] = model.get("tpm")

+            self.set_client(model=model)
+
        self.print_verbose(f"\nInitialized Model List {self.model_list}")
        self.model_names = [m["model_name"] for m in model_list]
@@ -1482,16 +1577,49 @@ class Router:
        Returns:
            The appropriate client based on the given client_type and kwargs.
        """
+        model_id = deployment["model_info"]["id"]
        if client_type == "async":
            if kwargs.get("stream") == True:
-                return deployment.get("stream_async_client", None)
+                cache_key = f"{model_id}_stream_async_client"
+                client = self.cache.get_cache(key=cache_key, local_only=True)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                return client
            else:
-                return deployment.get("async_client", None)
+                cache_key = f"{model_id}_async_client"
+                client = self.cache.get_cache(key=cache_key, local_only=True)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                return client
        else:
            if kwargs.get("stream") == True:
-                return deployment.get("stream_client", None)
+                cache_key = f"{model_id}_stream_client"
+                client = self.cache.get_cache(key=cache_key)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key)
+                return client
            else:
-                return deployment.get("client", None)
+                cache_key = f"{model_id}_client"
+                client = self.cache.get_cache(key=cache_key)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key)
+                return client

    def print_verbose(self, print_statement):
        try:

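The lookup flow above is the core of the fix: read the client from the in-memory cache, and if the TTL has evicted it, rebuild it via set_client() and read again. A usage sketch, reusing the `router` assumed in the earlier sketches:

    # Hypothetical sketch of the lazy re-initialization that _get_client() performs.
    deployment = router.model_list[0]

    client = router._get_client(
        deployment=deployment, kwargs={"stream": False}, client_type="async"
    )
    assert client is not None  # rebuilt on demand even after client_ttl has expired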

@@ -778,7 +778,8 @@ def test_reading_keys_os_environ():
                os.environ["AZURE_MAX_RETRIES"]
            ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
            print("passed testing of reading keys from os.environ")
-            async_client: openai.AsyncAzureOpenAI = model["async_client"]  # type: ignore
+            model_id = model["model_info"]["id"]
+            async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client")  # type: ignore
            assert async_client.api_key == os.environ["AZURE_API_KEY"]
            assert async_client.base_url == os.environ["AZURE_API_BASE"]
            assert async_client.max_retries == (
@@ -791,7 +792,7 @@ def test_reading_keys_os_environ():
            print("\n Testing async streaming client")
-            stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"]  # type: ignore
+            stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client")  # type: ignore
            assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
            assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
            assert stream_async_client.max_retries == (
@@ -803,7 +804,7 @@ def test_reading_keys_os_environ():
            print("async stream client set correctly!")

            print("\n Testing sync client")
-            client: openai.AzureOpenAI = model["client"]  # type: ignore
+            client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client")  # type: ignore
            assert client.api_key == os.environ["AZURE_API_KEY"]
            assert client.base_url == os.environ["AZURE_API_BASE"]
            assert client.max_retries == (
@@ -815,7 +816,7 @@ def test_reading_keys_os_environ():
            print("sync client set correctly!")

            print("\n Testing sync stream client")
-            stream_client: openai.AzureOpenAI = model["stream_client"]  # type: ignore
+            stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client")  # type: ignore
            assert stream_client.api_key == os.environ["AZURE_API_KEY"]
            assert stream_client.base_url == os.environ["AZURE_API_BASE"]
            assert stream_client.max_retries == (
@@ -877,7 +878,8 @@ def test_reading_openai_keys_os_environ():
                os.environ["AZURE_MAX_RETRIES"]
            ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
            print("passed testing of reading keys from os.environ")
-            async_client: openai.AsyncOpenAI = model["async_client"]  # type: ignore
+            model_id = model["model_info"]["id"]
+            async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_async_client")  # type: ignore
            assert async_client.api_key == os.environ["OPENAI_API_KEY"]
            assert async_client.max_retries == (
                os.environ["AZURE_MAX_RETRIES"]
@@ -889,7 +891,7 @@ def test_reading_openai_keys_os_environ():
            print("\n Testing async streaming client")
-            stream_async_client: openai.AsyncOpenAI = model["stream_async_client"]  # type: ignore
+            stream_async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_stream_async_client")  # type: ignore
            assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"]
            assert stream_async_client.max_retries == (
                os.environ["AZURE_MAX_RETRIES"]
@@ -900,7 +902,7 @@ def test_reading_openai_keys_os_environ():
            print("async stream client set correctly!")

            print("\n Testing sync client")
-            client: openai.AzureOpenAI = model["client"]  # type: ignore
+            client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_client")  # type: ignore
            assert client.api_key == os.environ["OPENAI_API_KEY"]
            assert client.max_retries == (
                os.environ["AZURE_MAX_RETRIES"]
@@ -911,7 +913,7 @@ def test_reading_openai_keys_os_environ():
            print("sync client set correctly!")

            print("\n Testing sync stream client")
-            stream_client: openai.AzureOpenAI = model["stream_client"]  # type: ignore
+            stream_client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_stream_client")  # type: ignore
            assert stream_client.api_key == os.environ["OPENAI_API_KEY"]
            assert stream_client.max_retries == (
                os.environ["AZURE_MAX_RETRIES"]


@@ -0,0 +1,78 @@
#### What this tests ####
# This tests client initialization + reinitialization on the router
#### What this tests ####
# This tests caching on the router

import sys, os, time
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import Router


async def test_router_init():
    """
    1. Initializes clients on the router with 0
    2. Checks if client is still valid
    3. Checks if new client was initialized
    """
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0613",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {"id": "1234"},
            "tpm": 100000,
            "rpm": 10000,
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
            },
            "tpm": 100000,
            "rpm": 10000,
        },
    ]
    messages = [
        {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
    ]
    client_ttl_time = 2
    router = Router(
        model_list=model_list,
        redis_host=os.environ["REDIS_HOST"],
        redis_password=os.environ["REDIS_PASSWORD"],
        redis_port=os.environ["REDIS_PORT"],
        cache_responses=True,
        timeout=30,
        routing_strategy="simple-shuffle",
        client_ttl=client_ttl_time,
    )
    model = "gpt-3.5-turbo"
    cache_key = f"1234_async_client"
    ## ASSERT IT EXISTS AT THE START ##
    assert router.cache.get_cache(key=cache_key) is not None
    response1 = await router.acompletion(model=model, messages=messages, temperature=1)
    await asyncio.sleep(client_ttl_time)
    ## ASSERT IT'S CLEARED FROM CACHE ##
    assert router.cache.get_cache(key=cache_key, local_only=True) is None
    ## ASSERT IT EXISTS AFTER RUNNING __GET_CLIENT() ##
    assert (
        router._get_client(
            deployment=model_list[0], client_type="async", kwargs={"stream": False}
        )
        is not None
    )


# asyncio.run(test_router_init())
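Since test_router_init is a coroutine and no pytest-asyncio marker is added in this file (an observation, not a change in this commit), one way to exercise it outside pytest is to uncomment the last line, i.e.:

    # Requires the OPENAI/AZURE/REDIS env vars used above to be set.
    import asyncio

    asyncio.run(test_router_init())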