fix(router.py): periodically re-initialize azure/openai clients to solve max conn issue

This commit is contained in:
Krrish Dholakia 2023-12-30 15:48:34 +05:30
parent d089157925
commit 69935db239
4 changed files with 451 additions and 242 deletions

View file

@ -133,30 +133,31 @@ class DualCache(BaseCache):
# If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache
def set_cache(self, key, value, **kwargs):
def set_cache(self, key, value, local_only: bool = False, **kwargs):
# Update both Redis and in-memory cache
try:
print_verbose(f"set cache: key: {key}; value: {value}")
if self.in_memory_cache is not None:
self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None:
if self.redis_cache is not None and local_only == False:
self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e:
print_verbose(e)
def get_cache(self, key, **kwargs):
def get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
print_verbose(f"get cache: cache key: {key}")
print_verbose(f"get cache: cache key: {key}; local_only: {local_only}")
result = None
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if self.redis_cache is not None:
if result is None and self.redis_cache is not None and local_only == False:
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(key, **kwargs)

View file

@ -84,6 +84,7 @@ class Router:
caching_groups: Optional[
List[tuple]
] = None, # if you want to cache across model groups
client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
## RELIABILITY ##
num_retries: int = 0,
timeout: Optional[float] = None,
@ -106,6 +107,43 @@ class Router:
[]
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
redis_cache = None
cache_config = {}
self.client_ttl = client_ttl
if redis_url is not None or (
redis_host is not None
and redis_port is not None
and redis_password is not None
):
cache_type = "redis"
if redis_url is not None:
cache_config["url"] = redis_url
if redis_host is not None:
cache_config["host"] = redis_host
if redis_port is not None:
cache_config["port"] = str(redis_port) # type: ignore
if redis_password is not None:
cache_config["password"] = redis_password
# Add additional key-value pairs from cache_kwargs
cache_config.update(cache_kwargs)
redis_cache = RedisCache(**cache_config)
if cache_responses:
if litellm.cache is None:
# the cache can be initialized on the proxy server. We should not overwrite it
litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
self.cache_responses = cache_responses
self.cache = DualCache(
redis_cache=redis_cache, in_memory_cache=InMemoryCache()
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
if model_list:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
@ -155,40 +193,6 @@ class Router:
{"caching_groups": caching_groups}
)
### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
redis_cache = None
cache_config = {}
if redis_url is not None or (
redis_host is not None
and redis_port is not None
and redis_password is not None
):
cache_type = "redis"
if redis_url is not None:
cache_config["url"] = redis_url
if redis_host is not None:
cache_config["host"] = redis_host
if redis_port is not None:
cache_config["port"] = str(redis_port) # type: ignore
if redis_password is not None:
cache_config["password"] = redis_password
# Add additional key-value pairs from cache_kwargs
cache_config.update(cache_kwargs)
redis_cache = RedisCache(**cache_config)
if cache_responses:
if litellm.cache is None:
# the cache can be initialized on the proxy server. We should not overwrite it
litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
self.cache_responses = cache_responses
self.cache = DualCache(
redis_cache=redis_cache, in_memory_cache=InMemoryCache()
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ###
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(
@ -1248,208 +1252,297 @@ class Router:
chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
return chosen_item
def set_client(self, model: dict):
"""
Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
"""
client_ttl = self.client_ttl
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
model_id = model["model_info"]["id"]
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
default_api_base = None
default_api_key = None
if custom_llm_provider in litellm.openai_compatible_providers:
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
model=model_name
)
default_api_base = api_base
default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith(
"os.environ/"
):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 2)
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
if "azure" in model_name:
if api_base is None:
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
azure_model = model_name.replace("azure/", "")
api_base += f"{azure_model}"
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients can have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
self.print_verbose(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
self.print_verbose(
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
def set_model_list(self, model_list: list):
self.model_list = copy.deepcopy(model_list)
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
import os
for model in self.model_list:
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
#### MODEL ID INIT ########
model_info = model.get("model_info", {})
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
model["model_info"] = model_info
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = (
custom_llm_provider or model_name.split("/", 1)[0] or ""
)
default_api_base = None
default_api_key = None
if custom_llm_provider in litellm.openai_compatible_providers:
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
model=model_name
)
default_api_base = api_base
default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith(
"os.environ/"
):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 2)
if isinstance(max_retries, str) and max_retries.startswith(
"os.environ/"
):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
if "azure" in model_name:
if api_base is None:
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
azure_model = model_name.replace("azure/", "")
api_base += f"{azure_model}"
model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
)
# streaming clients can have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
model["stream_client"] = openai.AzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
else:
self.print_verbose(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
)
model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
), # type: ignore
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
), # type: ignore
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
model["stream_client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
)
else:
self.print_verbose(
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
)
model["async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
)
model["client"] = openai.OpenAI(
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
)
# streaming clients should have diff timeouts
model["stream_client"] = openai.OpenAI(
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
)
############ End of initializing Clients for OpenAI/Azure ###################
#### DEPLOYMENT NAMES INIT ########
self.deployment_names.append(model["litellm_params"]["model"])
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
# for get_available_deployment, we use the litellm_param["rpm"]
# in this snippet we also set rpm to be a litellm_param
@ -1464,6 +1557,8 @@ class Router:
):
model["litellm_params"]["tpm"] = model.get("tpm")
self.set_client(model=model)
self.print_verbose(f"\nInitialized Model List {self.model_list}")
self.model_names = [m["model_name"] for m in model_list]
@ -1482,16 +1577,49 @@ class Router:
Returns:
The appropriate client based on the given client_type and kwargs.
"""
model_id = deployment["model_info"]["id"]
if client_type == "async":
if kwargs.get("stream") == True:
return deployment.get("stream_async_client", None)
cache_key = f"{model_id}_stream_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
return deployment.get("async_client", None)
cache_key = f"{model_id}_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
if kwargs.get("stream") == True:
return deployment.get("stream_client", None)
cache_key = f"{model_id}_stream_client"
client = self.cache.get_cache(key=cache_key)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
else:
return deployment.get("client", None)
cache_key = f"{model_id}_client"
client = self.cache.get_cache(key=cache_key)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
def print_verbose(self, print_statement):
try:

View file

@ -778,7 +778,8 @@ def test_reading_keys_os_environ():
os.environ["AZURE_MAX_RETRIES"]
), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
print("passed testing of reading keys from os.environ")
async_client: openai.AsyncAzureOpenAI = model["async_client"] # type: ignore
model_id = model["model_info"]["id"]
async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client") # type: ignore
assert async_client.api_key == os.environ["AZURE_API_KEY"]
assert async_client.base_url == os.environ["AZURE_API_BASE"]
assert async_client.max_retries == (
@ -791,7 +792,7 @@ def test_reading_keys_os_environ():
print("\n Testing async streaming client")
stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"] # type: ignore
stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client") # type: ignore
assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
assert stream_async_client.max_retries == (
@ -803,7 +804,7 @@ def test_reading_keys_os_environ():
print("async stream client set correctly!")
print("\n Testing sync client")
client: openai.AzureOpenAI = model["client"] # type: ignore
client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client") # type: ignore
assert client.api_key == os.environ["AZURE_API_KEY"]
assert client.base_url == os.environ["AZURE_API_BASE"]
assert client.max_retries == (
@ -815,7 +816,7 @@ def test_reading_keys_os_environ():
print("sync client set correctly!")
print("\n Testing sync stream client")
stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore
stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client") # type: ignore
assert stream_client.api_key == os.environ["AZURE_API_KEY"]
assert stream_client.base_url == os.environ["AZURE_API_BASE"]
assert stream_client.max_retries == (
@ -877,7 +878,8 @@ def test_reading_openai_keys_os_environ():
os.environ["AZURE_MAX_RETRIES"]
), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
print("passed testing of reading keys from os.environ")
async_client: openai.AsyncOpenAI = model["async_client"] # type: ignore
model_id = model["model_info"]["id"]
async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_async_client") # type: ignore
assert async_client.api_key == os.environ["OPENAI_API_KEY"]
assert async_client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]
@ -889,7 +891,7 @@ def test_reading_openai_keys_os_environ():
print("\n Testing async streaming client")
stream_async_client: openai.AsyncOpenAI = model["stream_async_client"] # type: ignore
stream_async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_stream_async_client") # type: ignore
assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"]
assert stream_async_client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]
@ -900,7 +902,7 @@ def test_reading_openai_keys_os_environ():
print("async stream client set correctly!")
print("\n Testing sync client")
client: openai.AzureOpenAI = model["client"] # type: ignore
client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_client") # type: ignore
assert client.api_key == os.environ["OPENAI_API_KEY"]
assert client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]
@ -911,7 +913,7 @@ def test_reading_openai_keys_os_environ():
print("sync client set correctly!")
print("\n Testing sync stream client")
stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore
stream_client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_stream_client") # type: ignore
assert stream_client.api_key == os.environ["OPENAI_API_KEY"]
assert stream_client.max_retries == (
os.environ["AZURE_MAX_RETRIES"]

View file

@ -0,0 +1,78 @@
#### What this tests ####
# This tests client initialization + reinitialization on the router
#### What this tests ####
# This tests caching on the router
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
async def test_router_init():
"""
1. Initializes clients on the router with 0
2. Checks if client is still valid
3. Checks if new client was initialized
"""
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"model_info": {"id": "1234"},
"tpm": 100000,
"rpm": 10000,
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION"),
},
"tpm": 100000,
"rpm": 10000,
},
]
messages = [
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
]
client_ttl_time = 2
router = Router(
model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
cache_responses=True,
timeout=30,
routing_strategy="simple-shuffle",
client_ttl=client_ttl_time,
)
model = "gpt-3.5-turbo"
cache_key = f"1234_async_client"
## ASSERT IT EXISTS AT THE START ##
assert router.cache.get_cache(key=cache_key) is not None
response1 = await router.acompletion(model=model, messages=messages, temperature=1)
await asyncio.sleep(client_ttl_time)
## ASSERT IT'S CLEARED FROM CACHE ##
assert router.cache.get_cache(key=cache_key, local_only=True) is None
## ASSERT IT EXISTS AFTER RUNNING __GET_CLIENT() ##
assert (
router._get_client(
deployment=model_list[0], client_type="async", kwargs={"stream": False}
)
is not None
)
# asyncio.run(test_router_init())