forked from phoenix/litellm-mirror
fix(router.py): periodically re-initialize azure/openai clients to solve max conn issue
This commit is contained in:
parent
d089157925
commit
69935db239
4 changed files with 451 additions and 242 deletions
|
@ -133,30 +133,31 @@ class DualCache(BaseCache):
|
|||
# If redis_cache is not provided, use the default RedisCache
|
||||
self.redis_cache = redis_cache
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
def set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
# Update both Redis and in-memory cache
|
||||
try:
|
||||
print_verbose(f"set cache: key: {key}; value: {value}")
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None:
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
self.redis_cache.set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(e)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
def get_cache(self, key, local_only: bool = False, **kwargs):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
print_verbose(f"get cache: cache key: {key}")
|
||||
print_verbose(f"get cache: cache key: {key}; local_only: {local_only}")
|
||||
result = None
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
|
||||
|
||||
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if self.redis_cache is not None:
|
||||
if result is None and self.redis_cache is not None and local_only == False:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = self.redis_cache.get_cache(key, **kwargs)
|
||||
|
||||
|
|
|
@ -84,6 +84,7 @@ class Router:
|
|||
caching_groups: Optional[
|
||||
List[tuple]
|
||||
] = None, # if you want to cache across model groups
|
||||
client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
|
||||
## RELIABILITY ##
|
||||
num_retries: int = 0,
|
||||
timeout: Optional[float] = None,
|
||||
|
@ -106,6 +107,43 @@ class Router:
|
|||
[]
|
||||
) # names of models under litellm_params. ex. azure/chatgpt-v-2
|
||||
self.deployment_latency_map = {}
|
||||
### CACHING ###
|
||||
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
|
||||
redis_cache = None
|
||||
cache_config = {}
|
||||
self.client_ttl = client_ttl
|
||||
if redis_url is not None or (
|
||||
redis_host is not None
|
||||
and redis_port is not None
|
||||
and redis_password is not None
|
||||
):
|
||||
cache_type = "redis"
|
||||
|
||||
if redis_url is not None:
|
||||
cache_config["url"] = redis_url
|
||||
|
||||
if redis_host is not None:
|
||||
cache_config["host"] = redis_host
|
||||
|
||||
if redis_port is not None:
|
||||
cache_config["port"] = str(redis_port) # type: ignore
|
||||
|
||||
if redis_password is not None:
|
||||
cache_config["password"] = redis_password
|
||||
|
||||
# Add additional key-value pairs from cache_kwargs
|
||||
cache_config.update(cache_kwargs)
|
||||
redis_cache = RedisCache(**cache_config)
|
||||
|
||||
if cache_responses:
|
||||
if litellm.cache is None:
|
||||
# the cache can be initialized on the proxy server. We should not overwrite it
|
||||
litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
|
||||
self.cache_responses = cache_responses
|
||||
self.cache = DualCache(
|
||||
redis_cache=redis_cache, in_memory_cache=InMemoryCache()
|
||||
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
|
||||
|
||||
if model_list:
|
||||
model_list = copy.deepcopy(model_list)
|
||||
self.set_model_list(model_list)
|
||||
|
@ -155,40 +193,6 @@ class Router:
|
|||
{"caching_groups": caching_groups}
|
||||
)
|
||||
|
||||
### CACHING ###
|
||||
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
|
||||
redis_cache = None
|
||||
cache_config = {}
|
||||
if redis_url is not None or (
|
||||
redis_host is not None
|
||||
and redis_port is not None
|
||||
and redis_password is not None
|
||||
):
|
||||
cache_type = "redis"
|
||||
|
||||
if redis_url is not None:
|
||||
cache_config["url"] = redis_url
|
||||
|
||||
if redis_host is not None:
|
||||
cache_config["host"] = redis_host
|
||||
|
||||
if redis_port is not None:
|
||||
cache_config["port"] = str(redis_port) # type: ignore
|
||||
|
||||
if redis_password is not None:
|
||||
cache_config["password"] = redis_password
|
||||
|
||||
# Add additional key-value pairs from cache_kwargs
|
||||
cache_config.update(cache_kwargs)
|
||||
redis_cache = RedisCache(**cache_config)
|
||||
if cache_responses:
|
||||
if litellm.cache is None:
|
||||
# the cache can be initialized on the proxy server. We should not overwrite it
|
||||
litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
|
||||
self.cache_responses = cache_responses
|
||||
self.cache = DualCache(
|
||||
redis_cache=redis_cache, in_memory_cache=InMemoryCache()
|
||||
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
|
||||
### ROUTING SETUP ###
|
||||
if routing_strategy == "least-busy":
|
||||
self.leastbusy_logger = LeastBusyLoggingHandler(
|
||||
|
@ -1248,208 +1252,297 @@ class Router:
|
|||
chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
|
||||
return chosen_item
|
||||
|
||||
def set_client(self, model: dict):
|
||||
"""
|
||||
Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
|
||||
"""
|
||||
client_ttl = self.client_ttl
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
model_name = litellm_params.get("model")
|
||||
model_id = model["model_info"]["id"]
|
||||
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
|
||||
default_api_base = None
|
||||
default_api_key = None
|
||||
if custom_llm_provider in litellm.openai_compatible_providers:
|
||||
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
|
||||
model=model_name
|
||||
)
|
||||
default_api_base = api_base
|
||||
default_api_key = api_key
|
||||
if (
|
||||
model_name in litellm.open_ai_chat_completion_models
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "openai"
|
||||
or "ft:gpt-3.5-turbo" in model_name
|
||||
or model_name in litellm.open_ai_embedding_models
|
||||
):
|
||||
# glorified / complicated reading of configs
|
||||
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
|
||||
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
|
||||
api_key = litellm_params.get("api_key") or default_api_key
|
||||
if api_key and api_key.startswith("os.environ/"):
|
||||
api_key_env_name = api_key.replace("os.environ/", "")
|
||||
api_key = litellm.get_secret(api_key_env_name)
|
||||
litellm_params["api_key"] = api_key
|
||||
|
||||
api_base = litellm_params.get("api_base")
|
||||
base_url = litellm_params.get("base_url")
|
||||
api_base = (
|
||||
api_base or base_url or default_api_base
|
||||
) # allow users to pass in `api_base` or `base_url` for azure
|
||||
if api_base and api_base.startswith("os.environ/"):
|
||||
api_base_env_name = api_base.replace("os.environ/", "")
|
||||
api_base = litellm.get_secret(api_base_env_name)
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
api_version = litellm_params.get("api_version")
|
||||
if api_version and api_version.startswith("os.environ/"):
|
||||
api_version_env_name = api_version.replace("os.environ/", "")
|
||||
api_version = litellm.get_secret(api_version_env_name)
|
||||
litellm_params["api_version"] = api_version
|
||||
|
||||
timeout = litellm_params.pop("timeout", None)
|
||||
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
|
||||
timeout_env_name = timeout.replace("os.environ/", "")
|
||||
timeout = litellm.get_secret(timeout_env_name)
|
||||
litellm_params["timeout"] = timeout
|
||||
|
||||
stream_timeout = litellm_params.pop(
|
||||
"stream_timeout", timeout
|
||||
) # if no stream_timeout is set, default to timeout
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
|
||||
stream_timeout = litellm.get_secret(stream_timeout_env_name)
|
||||
litellm_params["stream_timeout"] = stream_timeout
|
||||
|
||||
max_retries = litellm_params.pop("max_retries", 2)
|
||||
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
|
||||
max_retries_env_name = max_retries.replace("os.environ/", "")
|
||||
max_retries = litellm.get_secret(max_retries_env_name)
|
||||
litellm_params["max_retries"] = max_retries
|
||||
|
||||
if "azure" in model_name:
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
|
||||
)
|
||||
if api_version is None:
|
||||
api_version = "2023-07-01-preview"
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
azure_model = model_name.replace("azure/", "")
|
||||
api_base += f"{azure_model}"
|
||||
cache_key = f"{model_id}_async_client"
|
||||
_client = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
# streaming clients can have diff timeouts
|
||||
cache_key = f"{model_id}_stream_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
else:
|
||||
self.print_verbose(
|
||||
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
|
||||
)
|
||||
|
||||
cache_key = f"{model_id}_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
http_client=httpx.AsyncClient(
|
||||
transport=AsyncCustomHTTPTransport(),
|
||||
), # type: ignore
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
cache_key = f"{model_id}_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
http_client=httpx.Client(
|
||||
transport=CustomHTTPTransport(),
|
||||
), # type: ignore
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
# streaming clients should have diff timeouts
|
||||
cache_key = f"{model_id}_stream_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
else:
|
||||
self.print_verbose(
|
||||
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
|
||||
)
|
||||
cache_key = f"{model_id}_async_client"
|
||||
_client = openai.AsyncOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
cache_key = f"{model_id}_client"
|
||||
_client = openai.OpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
# streaming clients should have diff timeouts
|
||||
cache_key = f"{model_id}_stream_async_client"
|
||||
_client = openai.AsyncOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
# streaming clients should have diff timeouts
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
_client = openai.OpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
self.cache.set_cache(
|
||||
key=cache_key,
|
||||
value=_client,
|
||||
ttl=client_ttl,
|
||||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
def set_model_list(self, model_list: list):
|
||||
self.model_list = copy.deepcopy(model_list)
|
||||
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
|
||||
import os
|
||||
|
||||
for model in self.model_list:
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
model_name = litellm_params.get("model")
|
||||
#### MODEL ID INIT ########
|
||||
model_info = model.get("model_info", {})
|
||||
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
|
||||
model["model_info"] = model_info
|
||||
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
custom_llm_provider = (
|
||||
custom_llm_provider or model_name.split("/", 1)[0] or ""
|
||||
)
|
||||
default_api_base = None
|
||||
default_api_key = None
|
||||
if custom_llm_provider in litellm.openai_compatible_providers:
|
||||
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
|
||||
model=model_name
|
||||
)
|
||||
default_api_base = api_base
|
||||
default_api_key = api_key
|
||||
if (
|
||||
model_name in litellm.open_ai_chat_completion_models
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "openai"
|
||||
or "ft:gpt-3.5-turbo" in model_name
|
||||
or model_name in litellm.open_ai_embedding_models
|
||||
):
|
||||
# glorified / complicated reading of configs
|
||||
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
|
||||
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
|
||||
api_key = litellm_params.get("api_key") or default_api_key
|
||||
if api_key and api_key.startswith("os.environ/"):
|
||||
api_key_env_name = api_key.replace("os.environ/", "")
|
||||
api_key = litellm.get_secret(api_key_env_name)
|
||||
litellm_params["api_key"] = api_key
|
||||
|
||||
api_base = litellm_params.get("api_base")
|
||||
base_url = litellm_params.get("base_url")
|
||||
api_base = (
|
||||
api_base or base_url or default_api_base
|
||||
) # allow users to pass in `api_base` or `base_url` for azure
|
||||
if api_base and api_base.startswith("os.environ/"):
|
||||
api_base_env_name = api_base.replace("os.environ/", "")
|
||||
api_base = litellm.get_secret(api_base_env_name)
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
api_version = litellm_params.get("api_version")
|
||||
if api_version and api_version.startswith("os.environ/"):
|
||||
api_version_env_name = api_version.replace("os.environ/", "")
|
||||
api_version = litellm.get_secret(api_version_env_name)
|
||||
litellm_params["api_version"] = api_version
|
||||
|
||||
timeout = litellm_params.pop("timeout", None)
|
||||
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
|
||||
timeout_env_name = timeout.replace("os.environ/", "")
|
||||
timeout = litellm.get_secret(timeout_env_name)
|
||||
litellm_params["timeout"] = timeout
|
||||
|
||||
stream_timeout = litellm_params.pop(
|
||||
"stream_timeout", timeout
|
||||
) # if no stream_timeout is set, default to timeout
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
|
||||
stream_timeout = litellm.get_secret(stream_timeout_env_name)
|
||||
litellm_params["stream_timeout"] = stream_timeout
|
||||
|
||||
max_retries = litellm_params.pop("max_retries", 2)
|
||||
if isinstance(max_retries, str) and max_retries.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
max_retries_env_name = max_retries.replace("os.environ/", "")
|
||||
max_retries = litellm.get_secret(max_retries_env_name)
|
||||
litellm_params["max_retries"] = max_retries
|
||||
|
||||
if "azure" in model_name:
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
|
||||
)
|
||||
if api_version is None:
|
||||
api_version = "2023-07-01-preview"
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
azure_model = model_name.replace("azure/", "")
|
||||
api_base += f"{azure_model}"
|
||||
model["async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
model["client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
# streaming clients can have diff timeouts
|
||||
model["stream_async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
model["stream_client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
else:
|
||||
self.print_verbose(
|
||||
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
|
||||
)
|
||||
model["async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
http_client=httpx.AsyncClient(
|
||||
transport=AsyncCustomHTTPTransport(),
|
||||
), # type: ignore
|
||||
)
|
||||
model["client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
http_client=httpx.Client(
|
||||
transport=CustomHTTPTransport(),
|
||||
), # type: ignore
|
||||
)
|
||||
# streaming clients should have diff timeouts
|
||||
model["stream_async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
model["stream_client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
else:
|
||||
self.print_verbose(
|
||||
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
|
||||
)
|
||||
model["async_client"] = openai.AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
model["client"] = openai.OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
# streaming clients should have diff timeouts
|
||||
model["stream_async_client"] = openai.AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
# streaming clients should have diff timeouts
|
||||
model["stream_client"] = openai.OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
############ End of initializing Clients for OpenAI/Azure ###################
|
||||
#### DEPLOYMENT NAMES INIT ########
|
||||
self.deployment_names.append(model["litellm_params"]["model"])
|
||||
|
||||
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
|
||||
# for get_available_deployment, we use the litellm_param["rpm"]
|
||||
# in this snippet we also set rpm to be a litellm_param
|
||||
|
@ -1464,6 +1557,8 @@ class Router:
|
|||
):
|
||||
model["litellm_params"]["tpm"] = model.get("tpm")
|
||||
|
||||
self.set_client(model=model)
|
||||
|
||||
self.print_verbose(f"\nInitialized Model List {self.model_list}")
|
||||
self.model_names = [m["model_name"] for m in model_list]
|
||||
|
||||
|
@ -1482,16 +1577,49 @@ class Router:
|
|||
Returns:
|
||||
The appropriate client based on the given client_type and kwargs.
|
||||
"""
|
||||
model_id = deployment["model_info"]["id"]
|
||||
if client_type == "async":
|
||||
if kwargs.get("stream") == True:
|
||||
return deployment.get("stream_async_client", None)
|
||||
cache_key = f"{model_id}_stream_async_client"
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
if client is None:
|
||||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
self.set_client(model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
return client
|
||||
else:
|
||||
return deployment.get("async_client", None)
|
||||
cache_key = f"{model_id}_async_client"
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
if client is None:
|
||||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
self.set_client(model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||
return client
|
||||
else:
|
||||
if kwargs.get("stream") == True:
|
||||
return deployment.get("stream_client", None)
|
||||
cache_key = f"{model_id}_stream_client"
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
if client is None:
|
||||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
self.set_client(model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
return client
|
||||
else:
|
||||
return deployment.get("client", None)
|
||||
cache_key = f"{model_id}_client"
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
if client is None:
|
||||
"""
|
||||
Re-initialize the client
|
||||
"""
|
||||
self.set_client(model=deployment)
|
||||
client = self.cache.get_cache(key=cache_key)
|
||||
return client
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
try:
|
||||
|
|
|
@ -778,7 +778,8 @@ def test_reading_keys_os_environ():
|
|||
os.environ["AZURE_MAX_RETRIES"]
|
||||
), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
|
||||
print("passed testing of reading keys from os.environ")
|
||||
async_client: openai.AsyncAzureOpenAI = model["async_client"] # type: ignore
|
||||
model_id = model["model_info"]["id"]
|
||||
async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client") # type: ignore
|
||||
assert async_client.api_key == os.environ["AZURE_API_KEY"]
|
||||
assert async_client.base_url == os.environ["AZURE_API_BASE"]
|
||||
assert async_client.max_retries == (
|
||||
|
@ -791,7 +792,7 @@ def test_reading_keys_os_environ():
|
|||
|
||||
print("\n Testing async streaming client")
|
||||
|
||||
stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"] # type: ignore
|
||||
stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client") # type: ignore
|
||||
assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
|
||||
assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
|
||||
assert stream_async_client.max_retries == (
|
||||
|
@ -803,7 +804,7 @@ def test_reading_keys_os_environ():
|
|||
print("async stream client set correctly!")
|
||||
|
||||
print("\n Testing sync client")
|
||||
client: openai.AzureOpenAI = model["client"] # type: ignore
|
||||
client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client") # type: ignore
|
||||
assert client.api_key == os.environ["AZURE_API_KEY"]
|
||||
assert client.base_url == os.environ["AZURE_API_BASE"]
|
||||
assert client.max_retries == (
|
||||
|
@ -815,7 +816,7 @@ def test_reading_keys_os_environ():
|
|||
print("sync client set correctly!")
|
||||
|
||||
print("\n Testing sync stream client")
|
||||
stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore
|
||||
stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client") # type: ignore
|
||||
assert stream_client.api_key == os.environ["AZURE_API_KEY"]
|
||||
assert stream_client.base_url == os.environ["AZURE_API_BASE"]
|
||||
assert stream_client.max_retries == (
|
||||
|
@ -877,7 +878,8 @@ def test_reading_openai_keys_os_environ():
|
|||
os.environ["AZURE_MAX_RETRIES"]
|
||||
), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
|
||||
print("passed testing of reading keys from os.environ")
|
||||
async_client: openai.AsyncOpenAI = model["async_client"] # type: ignore
|
||||
model_id = model["model_info"]["id"]
|
||||
async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_async_client") # type: ignore
|
||||
assert async_client.api_key == os.environ["OPENAI_API_KEY"]
|
||||
assert async_client.max_retries == (
|
||||
os.environ["AZURE_MAX_RETRIES"]
|
||||
|
@ -889,7 +891,7 @@ def test_reading_openai_keys_os_environ():
|
|||
|
||||
print("\n Testing async streaming client")
|
||||
|
||||
stream_async_client: openai.AsyncOpenAI = model["stream_async_client"] # type: ignore
|
||||
stream_async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_stream_async_client") # type: ignore
|
||||
assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"]
|
||||
assert stream_async_client.max_retries == (
|
||||
os.environ["AZURE_MAX_RETRIES"]
|
||||
|
@ -900,7 +902,7 @@ def test_reading_openai_keys_os_environ():
|
|||
print("async stream client set correctly!")
|
||||
|
||||
print("\n Testing sync client")
|
||||
client: openai.AzureOpenAI = model["client"] # type: ignore
|
||||
client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_client") # type: ignore
|
||||
assert client.api_key == os.environ["OPENAI_API_KEY"]
|
||||
assert client.max_retries == (
|
||||
os.environ["AZURE_MAX_RETRIES"]
|
||||
|
@ -911,7 +913,7 @@ def test_reading_openai_keys_os_environ():
|
|||
print("sync client set correctly!")
|
||||
|
||||
print("\n Testing sync stream client")
|
||||
stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore
|
||||
stream_client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_stream_client") # type: ignore
|
||||
assert stream_client.api_key == os.environ["OPENAI_API_KEY"]
|
||||
assert stream_client.max_retries == (
|
||||
os.environ["AZURE_MAX_RETRIES"]
|
||||
|
|
78
litellm/tests/test_router_client_init.py
Normal file
78
litellm/tests/test_router_client_init.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
#### What this tests ####
|
||||
# This tests client initialization + reinitialization on the router
|
||||
|
||||
#### What this tests ####
|
||||
# This tests caching on the router
|
||||
import sys, os, time
|
||||
import traceback, asyncio
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm import Router
|
||||
|
||||
|
||||
async def test_router_init():
|
||||
"""
|
||||
1. Initializes clients on the router with 0
|
||||
2. Checks if client is still valid
|
||||
3. Checks if new client was initialized
|
||||
"""
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-0613",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
"model_info": {"id": "1234"},
|
||||
"tpm": 100000,
|
||||
"rpm": 10000,
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
},
|
||||
"tpm": 100000,
|
||||
"rpm": 10000,
|
||||
},
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
|
||||
]
|
||||
client_ttl_time = 2
|
||||
router = Router(
|
||||
model_list=model_list,
|
||||
redis_host=os.environ["REDIS_HOST"],
|
||||
redis_password=os.environ["REDIS_PASSWORD"],
|
||||
redis_port=os.environ["REDIS_PORT"],
|
||||
cache_responses=True,
|
||||
timeout=30,
|
||||
routing_strategy="simple-shuffle",
|
||||
client_ttl=client_ttl_time,
|
||||
)
|
||||
model = "gpt-3.5-turbo"
|
||||
cache_key = f"1234_async_client"
|
||||
## ASSERT IT EXISTS AT THE START ##
|
||||
assert router.cache.get_cache(key=cache_key) is not None
|
||||
response1 = await router.acompletion(model=model, messages=messages, temperature=1)
|
||||
await asyncio.sleep(client_ttl_time)
|
||||
## ASSERT IT'S CLEARED FROM CACHE ##
|
||||
assert router.cache.get_cache(key=cache_key, local_only=True) is None
|
||||
## ASSERT IT EXISTS AFTER RUNNING __GET_CLIENT() ##
|
||||
assert (
|
||||
router._get_client(
|
||||
deployment=model_list[0], client_type="async", kwargs={"stream": False}
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
# asyncio.run(test_router_init())
|
Loading…
Add table
Add a link
Reference in a new issue