forked from phoenix/litellm-mirror
fix(router.py): periodically re-initialize azure/openai clients to solve max conn issue
This commit is contained in:
parent d089157925
commit 69935db239
4 changed files with 451 additions and 242 deletions
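Editor's note: the commit's approach is to cache each Azure/OpenAI client under a TTL and rebuild it lazily once the entry expires, so a long-lived router process stops holding on to an exhausted connection pool. A minimal, library-agnostic sketch of that pattern (all names below are illustrative, not litellm's API):

import time

class TTLClientCache:
    """Tiny in-memory TTL cache: entries expire after `ttl` seconds."""

    def __init__(self):
        self._store = {}  # key -> (value, expires_at)

    def set(self, key, value, ttl: float):
        self._store[key] = (value, time.time() + ttl)

    def get(self, key):
        item = self._store.get(key)
        if item is None:
            return None
        value, expires_at = item
        if time.time() > expires_at:
            del self._store[key]  # expired -> force re-initialization
            return None
        return value

def get_client(cache, model_id, make_client, ttl=3600):
    # Lazily (re)build the client whenever the cached one has expired.
    client = cache.get(f"{model_id}_client")
    if client is None:
        client = make_client()
        cache.set(f"{model_id}_client", client, ttl=ttl)
    return client

The diff below implements the same idea with litellm's DualCache and per-deployment cache keys.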
@@ -133,30 +133,31 @@ class DualCache(BaseCache):
         # If redis_cache is not provided, use the default RedisCache
         self.redis_cache = redis_cache

-    def set_cache(self, key, value, **kwargs):
+    def set_cache(self, key, value, local_only: bool = False, **kwargs):
         # Update both Redis and in-memory cache
         try:
             print_verbose(f"set cache: key: {key}; value: {value}")
             if self.in_memory_cache is not None:
                 self.in_memory_cache.set_cache(key, value, **kwargs)

-            if self.redis_cache is not None:
+            if self.redis_cache is not None and local_only == False:
                 self.redis_cache.set_cache(key, value, **kwargs)
         except Exception as e:
             print_verbose(e)

-    def get_cache(self, key, **kwargs):
+    def get_cache(self, key, local_only: bool = False, **kwargs):
         # Try to fetch from in-memory cache first
         try:
-            print_verbose(f"get cache: cache key: {key}")
+            print_verbose(f"get cache: cache key: {key}; local_only: {local_only}")
             result = None
             if self.in_memory_cache is not None:
                 in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

+                print_verbose(f"in_memory_result: {in_memory_result}")
                 if in_memory_result is not None:
                     result = in_memory_result

-            if self.redis_cache is not None:
+            if result is None and self.redis_cache is not None and local_only == False:
                 # If not found in in-memory cache, try fetching from Redis
                 redis_result = self.redis_cache.get_cache(key, **kwargs)

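A hedged usage sketch of the new `local_only` flag: live SDK client objects belong in the process-local cache only, not in Redis. This assumes DualCache and InMemoryCache are importable from litellm.caching and that the in-memory cache honors the ttl kwarg, which the new router test at the end of this commit relies on:

from litellm.caching import DualCache, InMemoryCache

cache = DualCache(redis_cache=None, in_memory_cache=InMemoryCache())

# Per-process objects (e.g. an SDK client) stay out of Redis:
cache.set_cache(key="1234_async_client", value=object(), ttl=3600, local_only=True)

# Reads can likewise be restricted to the in-memory layer:
client = cache.get_cache(key="1234_async_client", local_only=True)
assert client is not None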
@@ -84,6 +84,7 @@ class Router:
         caching_groups: Optional[
             List[tuple]
         ] = None,  # if you want to cache across model groups
+        client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
         ## RELIABILITY ##
         num_retries: int = 0,
         timeout: Optional[float] = None,
@@ -106,6 +107,43 @@ class Router:
             []
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
+        ### CACHING ###
+        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
+        redis_cache = None
+        cache_config = {}
+        self.client_ttl = client_ttl
+        if redis_url is not None or (
+            redis_host is not None
+            and redis_port is not None
+            and redis_password is not None
+        ):
+            cache_type = "redis"
+
+            if redis_url is not None:
+                cache_config["url"] = redis_url
+
+            if redis_host is not None:
+                cache_config["host"] = redis_host
+
+            if redis_port is not None:
+                cache_config["port"] = str(redis_port)  # type: ignore
+
+            if redis_password is not None:
+                cache_config["password"] = redis_password
+
+            # Add additional key-value pairs from cache_kwargs
+            cache_config.update(cache_kwargs)
+            redis_cache = RedisCache(**cache_config)
+
+        if cache_responses:
+            if litellm.cache is None:
+                # the cache can be initialized on the proxy server. We should not overwrite it
+                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
+            self.cache_responses = cache_responses
+        self.cache = DualCache(
+            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
+        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
+
         if model_list:
             model_list = copy.deepcopy(model_list)
             self.set_model_list(model_list)
@@ -155,40 +193,6 @@ class Router:
                 {"caching_groups": caching_groups}
             )

-        ### CACHING ###
-        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
-        redis_cache = None
-        cache_config = {}
-        if redis_url is not None or (
-            redis_host is not None
-            and redis_port is not None
-            and redis_password is not None
-        ):
-            cache_type = "redis"
-
-            if redis_url is not None:
-                cache_config["url"] = redis_url
-
-            if redis_host is not None:
-                cache_config["host"] = redis_host
-
-            if redis_port is not None:
-                cache_config["port"] = str(redis_port)  # type: ignore
-
-            if redis_password is not None:
-                cache_config["password"] = redis_password
-
-            # Add additional key-value pairs from cache_kwargs
-            cache_config.update(cache_kwargs)
-            redis_cache = RedisCache(**cache_config)
-
-        if cache_responses:
-            if litellm.cache is None:
-                # the cache can be initialized on the proxy server. We should not overwrite it
-                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
-            self.cache_responses = cache_responses
-        self.cache = DualCache(
-            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
-        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ### ROUTING SETUP ###
         if routing_strategy == "least-busy":
             self.leastbusy_logger = LeastBusyLoggingHandler(
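Taken together with the new `client_ttl` parameter above, the constructor can now be told how long its cached clients live. A sketch of building a Router with a short TTL, mirroring the new test added in this commit (model names and environment variables are placeholders):

import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
            },
        }
    ],
    client_ttl=600,  # re-initialize cached Azure/OpenAI clients every 10 minutes
)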
@@ -1248,208 +1252,297 @@ class Router:
             chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
             return chosen_item

+    def set_client(self, model: dict):
+        """
+        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        """
+        client_ttl = self.client_ttl
+        litellm_params = model.get("litellm_params", {})
+        model_name = litellm_params.get("model")
+        model_id = model["model_info"]["id"]
+        #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
+        custom_llm_provider = litellm_params.get("custom_llm_provider")
+        custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
+        default_api_base = None
+        default_api_key = None
+        if custom_llm_provider in litellm.openai_compatible_providers:
+            _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
+                model=model_name
+            )
+            default_api_base = api_base
+            default_api_key = api_key
+        if (
+            model_name in litellm.open_ai_chat_completion_models
+            or custom_llm_provider in litellm.openai_compatible_providers
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "openai"
+            or "ft:gpt-3.5-turbo" in model_name
+            or model_name in litellm.open_ai_embedding_models
+        ):
+            # glorified / complicated reading of configs
+            # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
+            # we do this here because we init clients for Azure, OpenAI and we need to set the right key
+            api_key = litellm_params.get("api_key") or default_api_key
+            if api_key and api_key.startswith("os.environ/"):
+                api_key_env_name = api_key.replace("os.environ/", "")
+                api_key = litellm.get_secret(api_key_env_name)
+                litellm_params["api_key"] = api_key
+
+            api_base = litellm_params.get("api_base")
+            base_url = litellm_params.get("base_url")
+            api_base = (
+                api_base or base_url or default_api_base
+            )  # allow users to pass in `api_base` or `base_url` for azure
+            if api_base and api_base.startswith("os.environ/"):
+                api_base_env_name = api_base.replace("os.environ/", "")
+                api_base = litellm.get_secret(api_base_env_name)
+                litellm_params["api_base"] = api_base
+
+            api_version = litellm_params.get("api_version")
+            if api_version and api_version.startswith("os.environ/"):
+                api_version_env_name = api_version.replace("os.environ/", "")
+                api_version = litellm.get_secret(api_version_env_name)
+                litellm_params["api_version"] = api_version
+
+            timeout = litellm_params.pop("timeout", None)
+            if isinstance(timeout, str) and timeout.startswith("os.environ/"):
+                timeout_env_name = timeout.replace("os.environ/", "")
+                timeout = litellm.get_secret(timeout_env_name)
+                litellm_params["timeout"] = timeout
+
+            stream_timeout = litellm_params.pop(
+                "stream_timeout", timeout
+            )  # if no stream_timeout is set, default to timeout
+            if isinstance(stream_timeout, str) and stream_timeout.startswith(
+                "os.environ/"
+            ):
+                stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
+                stream_timeout = litellm.get_secret(stream_timeout_env_name)
+                litellm_params["stream_timeout"] = stream_timeout
+
+            max_retries = litellm_params.pop("max_retries", 2)
+            if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
+                max_retries_env_name = max_retries.replace("os.environ/", "")
+                max_retries = litellm.get_secret(max_retries_env_name)
+                litellm_params["max_retries"] = max_retries
+
+            if "azure" in model_name:
+                if api_base is None:
+                    raise ValueError(
+                        f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
+                    )
+                if api_version is None:
+                    api_version = "2023-07-01-preview"
+                if "gateway.ai.cloudflare.com" in api_base:
+                    if not api_base.endswith("/"):
+                        api_base += "/"
+                    azure_model = model_name.replace("azure/", "")
+                    api_base += f"{azure_model}"
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    # streaming clients can have diff timeouts
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    cache_key = f"{model_id}_stream_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                else:
+                    self.print_verbose(
+                        f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
+                    )
+
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.AsyncClient(
+                            transport=AsyncCustomHTTPTransport(),
+                        ),  # type: ignore
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    cache_key = f"{model_id}_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.Client(
+                            transport=CustomHTTPTransport(),
+                        ),  # type: ignore
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    # streaming clients should have diff timeouts
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    cache_key = f"{model_id}_stream_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+            else:
+                self.print_verbose(
+                    f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
+                )
+                cache_key = f"{model_id}_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                cache_key = f"{model_id}_client"
+                _client = openai.OpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                # streaming clients should have diff timeouts
+                cache_key = f"{model_id}_stream_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=stream_timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                # streaming clients should have diff timeouts
+                cache_key = f"{model_id}_stream_client"
+                _client = openai.OpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=stream_timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
     def set_model_list(self, model_list: list):
         self.model_list = copy.deepcopy(model_list)
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         import os

         for model in self.model_list:
-            litellm_params = model.get("litellm_params", {})
-            model_name = litellm_params.get("model")
             #### MODEL ID INIT ########
             model_info = model.get("model_info", {})
             model_info["id"] = model_info.get("id", str(uuid.uuid4()))
             model["model_info"] = model_info
-            #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
+            #### DEPLOYMENT NAMES INIT ########
-            custom_llm_provider = litellm_params.get("custom_llm_provider")
-            custom_llm_provider = (
-                custom_llm_provider or model_name.split("/", 1)[0] or ""
-            )
-            default_api_base = None
-            default_api_key = None
-            if custom_llm_provider in litellm.openai_compatible_providers:
-                _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
-                    model=model_name
-                )
-                default_api_base = api_base
-                default_api_key = api_key
-            if (
-                model_name in litellm.open_ai_chat_completion_models
-                or custom_llm_provider in litellm.openai_compatible_providers
-                or custom_llm_provider == "azure"
-                or custom_llm_provider == "custom_openai"
-                or custom_llm_provider == "openai"
-                or "ft:gpt-3.5-turbo" in model_name
-                or model_name in litellm.open_ai_embedding_models
-            ):
-                # glorified / complicated reading of configs
-                # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
-                # we do this here because we init clients for Azure, OpenAI and we need to set the right key
-                api_key = litellm_params.get("api_key") or default_api_key
-                if api_key and api_key.startswith("os.environ/"):
-                    api_key_env_name = api_key.replace("os.environ/", "")
-                    api_key = litellm.get_secret(api_key_env_name)
-                    litellm_params["api_key"] = api_key
-
-                api_base = litellm_params.get("api_base")
-                base_url = litellm_params.get("base_url")
-                api_base = (
-                    api_base or base_url or default_api_base
-                )  # allow users to pass in `api_base` or `base_url` for azure
-                if api_base and api_base.startswith("os.environ/"):
-                    api_base_env_name = api_base.replace("os.environ/", "")
-                    api_base = litellm.get_secret(api_base_env_name)
-                    litellm_params["api_base"] = api_base
-
-                api_version = litellm_params.get("api_version")
-                if api_version and api_version.startswith("os.environ/"):
-                    api_version_env_name = api_version.replace("os.environ/", "")
-                    api_version = litellm.get_secret(api_version_env_name)
-                    litellm_params["api_version"] = api_version
-
-                timeout = litellm_params.pop("timeout", None)
-                if isinstance(timeout, str) and timeout.startswith("os.environ/"):
-                    timeout_env_name = timeout.replace("os.environ/", "")
-                    timeout = litellm.get_secret(timeout_env_name)
-                    litellm_params["timeout"] = timeout
-
-                stream_timeout = litellm_params.pop(
-                    "stream_timeout", timeout
-                )  # if no stream_timeout is set, default to timeout
-                if isinstance(stream_timeout, str) and stream_timeout.startswith(
-                    "os.environ/"
-                ):
-                    stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
-                    stream_timeout = litellm.get_secret(stream_timeout_env_name)
-                    litellm_params["stream_timeout"] = stream_timeout
-
-                max_retries = litellm_params.pop("max_retries", 2)
-                if isinstance(max_retries, str) and max_retries.startswith(
-                    "os.environ/"
-                ):
-                    max_retries_env_name = max_retries.replace("os.environ/", "")
-                    max_retries = litellm.get_secret(max_retries_env_name)
-                    litellm_params["max_retries"] = max_retries
-
-                if "azure" in model_name:
-                    if api_base is None:
-                        raise ValueError(
-                            f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
-                        )
-                    if api_version is None:
-                        api_version = "2023-07-01-preview"
-                    if "gateway.ai.cloudflare.com" in api_base:
-                        if not api_base.endswith("/"):
-                            api_base += "/"
-                        azure_model = model_name.replace("azure/", "")
-                        api_base += f"{azure_model}"
-                        model["async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                        )
-                        model["client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                        )
-
-                        # streaming clients can have diff timeouts
-                        model["stream_async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-                        model["stream_client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-                    else:
-                        self.print_verbose(
-                            f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
-                        )
-                        model["async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                            http_client=httpx.AsyncClient(
-                                transport=AsyncCustomHTTPTransport(),
-                            ),  # type: ignore
-                        )
-                        model["client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                            http_client=httpx.Client(
-                                transport=CustomHTTPTransport(),
-                            ),  # type: ignore
-                        )
-                        # streaming clients should have diff timeouts
-                        model["stream_async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-
-                        model["stream_client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-
-                else:
-                    self.print_verbose(
-                        f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
-                    )
-                    model["async_client"] = openai.AsyncOpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=timeout,
-                        max_retries=max_retries,
-                    )
-                    model["client"] = openai.OpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=timeout,
-                        max_retries=max_retries,
-                    )
-
-                    # streaming clients should have diff timeouts
-                    model["stream_async_client"] = openai.AsyncOpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=stream_timeout,
-                        max_retries=max_retries,
-                    )
-
-                    # streaming clients should have diff timeouts
-                    model["stream_client"] = openai.OpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=stream_timeout,
-                        max_retries=max_retries,
-                    )
-
-            ############ End of initializing Clients for OpenAI/Azure ###################
             self.deployment_names.append(model["litellm_params"]["model"])

             ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
             # for get_available_deployment, we use the litellm_param["rpm"]
             # in this snippet we also set rpm to be a litellm_param
@@ -1464,6 +1557,8 @@ class Router:
             ):
                 model["litellm_params"]["tpm"] = model.get("tpm")

+            self.set_client(model=model)
+
         self.print_verbose(f"\nInitialized Model List {self.model_list}")
         self.model_names = [m["model_name"] for m in model_list]

@@ -1482,16 +1577,49 @@ class Router:
         Returns:
             The appropriate client based on the given client_type and kwargs.
         """
+        model_id = deployment["model_info"]["id"]
         if client_type == "async":
             if kwargs.get("stream") == True:
-                return deployment.get("stream_async_client", None)
+                cache_key = f"{model_id}_stream_async_client"
+                client = self.cache.get_cache(key=cache_key, local_only=True)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                return client
             else:
-                return deployment.get("async_client", None)
+                cache_key = f"{model_id}_async_client"
+                client = self.cache.get_cache(key=cache_key, local_only=True)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                return client
         else:
             if kwargs.get("stream") == True:
-                return deployment.get("stream_client", None)
+                cache_key = f"{model_id}_stream_client"
+                client = self.cache.get_cache(key=cache_key)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key)
+                return client
             else:
-                return deployment.get("client", None)
+                cache_key = f"{model_id}_client"
+                client = self.cache.get_cache(key=cache_key)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key)
+                return client

     def print_verbose(self, print_statement):
         try:
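The practical effect: once a cached client has aged out, the next `_get_client` call falls through to `set_client` and returns a fresh client. A sketch of exercising that flow, closely following the new test file added in this commit (assumes a router built with a short client_ttl as in the example above, and `deployment` taken from router.model_list):

import asyncio

async def check_client_reinit(router, deployment, ttl_seconds):
    # `deployment` is one entry of router.model_list; its id was set in set_model_list().
    model_id = deployment["model_info"]["id"]
    cache_key = f"{model_id}_async_client"

    # Right after Router.__init__, set_client() has populated the local cache.
    assert router.cache.get_cache(key=cache_key, local_only=True) is not None

    # Once the TTL lapses, the cached client is dropped ...
    await asyncio.sleep(ttl_seconds)
    assert router.cache.get_cache(key=cache_key, local_only=True) is None

    # ... and _get_client() transparently rebuilds it via set_client().
    client = router._get_client(
        deployment=deployment, client_type="async", kwargs={"stream": False}
    )
    assert client is not None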
@@ -778,7 +778,8 @@ def test_reading_keys_os_environ():
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
         print("passed testing of reading keys from os.environ")
-        async_client: openai.AsyncAzureOpenAI = model["async_client"]  # type: ignore
+        model_id = model["model_info"]["id"]
+        async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client")  # type: ignore
         assert async_client.api_key == os.environ["AZURE_API_KEY"]
         assert async_client.base_url == os.environ["AZURE_API_BASE"]
         assert async_client.max_retries == (
@@ -791,7 +792,7 @@ def test_reading_keys_os_environ():

         print("\n Testing async streaming client")

-        stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"]  # type: ignore
+        stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client")  # type: ignore
         assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
         assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
         assert stream_async_client.max_retries == (
@@ -803,7 +804,7 @@ def test_reading_keys_os_environ():
         print("async stream client set correctly!")

         print("\n Testing sync client")
-        client: openai.AzureOpenAI = model["client"]  # type: ignore
+        client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client")  # type: ignore
         assert client.api_key == os.environ["AZURE_API_KEY"]
         assert client.base_url == os.environ["AZURE_API_BASE"]
         assert client.max_retries == (
@@ -815,7 +816,7 @@ def test_reading_keys_os_environ():
         print("sync client set correctly!")

         print("\n Testing sync stream client")
-        stream_client: openai.AzureOpenAI = model["stream_client"]  # type: ignore
+        stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client")  # type: ignore
         assert stream_client.api_key == os.environ["AZURE_API_KEY"]
         assert stream_client.base_url == os.environ["AZURE_API_BASE"]
         assert stream_client.max_retries == (
@@ -877,7 +878,8 @@ def test_reading_openai_keys_os_environ():
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
         print("passed testing of reading keys from os.environ")
-        async_client: openai.AsyncOpenAI = model["async_client"]  # type: ignore
+        model_id = model["model_info"]["id"]
+        async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_async_client")  # type: ignore
         assert async_client.api_key == os.environ["OPENAI_API_KEY"]
         assert async_client.max_retries == (
             os.environ["AZURE_MAX_RETRIES"]
|
@ -889,7 +891,7 @@ def test_reading_openai_keys_os_environ():
|
||||||
|
|
||||||
print("\n Testing async streaming client")
|
print("\n Testing async streaming client")
|
||||||
|
|
||||||
stream_async_client: openai.AsyncOpenAI = model["stream_async_client"] # type: ignore
|
stream_async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_stream_async_client") # type: ignore
|
||||||
assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"]
|
assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"]
|
||||||
assert stream_async_client.max_retries == (
|
assert stream_async_client.max_retries == (
|
||||||
os.environ["AZURE_MAX_RETRIES"]
|
os.environ["AZURE_MAX_RETRIES"]
|
||||||
|
@@ -900,7 +902,7 @@ def test_reading_openai_keys_os_environ():
         print("async stream client set correctly!")

         print("\n Testing sync client")
-        client: openai.AzureOpenAI = model["client"]  # type: ignore
+        client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_client")  # type: ignore
         assert client.api_key == os.environ["OPENAI_API_KEY"]
         assert client.max_retries == (
             os.environ["AZURE_MAX_RETRIES"]
@@ -911,7 +913,7 @@ def test_reading_openai_keys_os_environ():
         print("sync client set correctly!")

         print("\n Testing sync stream client")
-        stream_client: openai.AzureOpenAI = model["stream_client"]  # type: ignore
+        stream_client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_stream_client")  # type: ignore
         assert stream_client.api_key == os.environ["OPENAI_API_KEY"]
         assert stream_client.max_retries == (
             os.environ["AZURE_MAX_RETRIES"]
litellm/tests/test_router_client_init.py (new file, 78 additions)
@@ -0,0 +1,78 @@
+#### What this tests ####
+# This tests client initialization + reinitialization on the router
+
+#### What this tests ####
+# This tests caching on the router
+import sys, os, time
+import traceback, asyncio
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import Router
+
+
+async def test_router_init():
+    """
+    1. Initializes clients on the router with 0
+    2. Checks if client is still valid
+    3. Checks if new client was initialized
+    """
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo-0613",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "model_info": {"id": "1234"},
+            "tpm": 100000,
+            "rpm": 10000,
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+            },
+            "tpm": 100000,
+            "rpm": 10000,
+        },
+    ]
+
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
+    ]
+    client_ttl_time = 2
+    router = Router(
+        model_list=model_list,
+        redis_host=os.environ["REDIS_HOST"],
+        redis_password=os.environ["REDIS_PASSWORD"],
+        redis_port=os.environ["REDIS_PORT"],
+        cache_responses=True,
+        timeout=30,
+        routing_strategy="simple-shuffle",
+        client_ttl=client_ttl_time,
+    )
+    model = "gpt-3.5-turbo"
+    cache_key = f"1234_async_client"
+    ## ASSERT IT EXISTS AT THE START ##
+    assert router.cache.get_cache(key=cache_key) is not None
+    response1 = await router.acompletion(model=model, messages=messages, temperature=1)
+    await asyncio.sleep(client_ttl_time)
+    ## ASSERT IT'S CLEARED FROM CACHE ##
+    assert router.cache.get_cache(key=cache_key, local_only=True) is None
+    ## ASSERT IT EXISTS AFTER RUNNING __GET_CLIENT() ##
+    assert (
+        router._get_client(
+            deployment=model_list[0], client_type="async", kwargs={"stream": False}
+        )
+        is not None
+    )
+
+
+# asyncio.run(test_router_init())