diff --git a/litellm/caching.py b/litellm/caching.py
index 2c394ea64..6c2ec0356 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -133,30 +133,31 @@ class DualCache(BaseCache):
         # If redis_cache is not provided, use the default RedisCache
         self.redis_cache = redis_cache
 
-    def set_cache(self, key, value, **kwargs):
+    def set_cache(self, key, value, local_only: bool = False, **kwargs):
         # Update both Redis and in-memory cache
         try:
             print_verbose(f"set cache: key: {key}; value: {value}")
             if self.in_memory_cache is not None:
                 self.in_memory_cache.set_cache(key, value, **kwargs)
 
-            if self.redis_cache is not None:
+            if self.redis_cache is not None and local_only == False:
                 self.redis_cache.set_cache(key, value, **kwargs)
         except Exception as e:
             print_verbose(e)
 
-    def get_cache(self, key, **kwargs):
+    def get_cache(self, key, local_only: bool = False, **kwargs):
         # Try to fetch from in-memory cache first
         try:
-            print_verbose(f"get cache: cache key: {key}")
+            print_verbose(f"get cache: cache key: {key}; local_only: {local_only}")
             result = None
             if self.in_memory_cache is not None:
                 in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
+                print_verbose(f"in_memory_result: {in_memory_result}")
                 if in_memory_result is not None:
                     result = in_memory_result
 
-            if self.redis_cache is not None:
+            if result is None and self.redis_cache is not None and local_only == False:
                 # If not found in in-memory cache, try fetching from Redis
                 redis_result = self.redis_cache.get_cache(key, **kwargs)
diff --git a/litellm/router.py b/litellm/router.py
index a58ceeaec..4e953c740 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -84,6 +84,7 @@ class Router:
         caching_groups: Optional[
             List[tuple]
         ] = None,  # if you want to cache across model groups
+        client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
         ## RELIABILITY ##
         num_retries: int = 0,
         timeout: Optional[float] = None,
@@ -106,6 +107,43 @@ class Router:
             []
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
+        ### CACHING ###
+        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
+        redis_cache = None
+        cache_config = {}
+        self.client_ttl = client_ttl
+        if redis_url is not None or (
+            redis_host is not None
+            and redis_port is not None
+            and redis_password is not None
+        ):
+            cache_type = "redis"
+
+            if redis_url is not None:
+                cache_config["url"] = redis_url
+
+            if redis_host is not None:
+                cache_config["host"] = redis_host
+
+            if redis_port is not None:
+                cache_config["port"] = str(redis_port)  # type: ignore
+
+            if redis_password is not None:
+                cache_config["password"] = redis_password
+
+            # Add additional key-value pairs from cache_kwargs
+            cache_config.update(cache_kwargs)
+            redis_cache = RedisCache(**cache_config)
+
+        if cache_responses:
+            if litellm.cache is None:
+                # the cache can be initialized on the proxy server. We should not overwrite it
+                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
+            self.cache_responses = cache_responses
+        self.cache = DualCache(
+            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
+        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
+
         if model_list:
             model_list = copy.deepcopy(model_list)
             self.set_model_list(model_list)
@@ -155,40 +193,6 @@ class Router:
                 {"caching_groups": caching_groups}
             )
 
-        ### CACHING ###
-        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
-        redis_cache = None
-        cache_config = {}
-        if redis_url is not None or (
-            redis_host is not None
-            and redis_port is not None
-            and redis_password is not None
-        ):
-            cache_type = "redis"
-
-            if redis_url is not None:
-                cache_config["url"] = redis_url
-
-            if redis_host is not None:
-                cache_config["host"] = redis_host
-
-            if redis_port is not None:
-                cache_config["port"] = str(redis_port)  # type: ignore
-
-            if redis_password is not None:
-                cache_config["password"] = redis_password
-
-            # Add additional key-value pairs from cache_kwargs
-            cache_config.update(cache_kwargs)
-            redis_cache = RedisCache(**cache_config)
-        if cache_responses:
-            if litellm.cache is None:
-                # the cache can be initialized on the proxy server. We should not overwrite it
-                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
-            self.cache_responses = cache_responses
-        self.cache = DualCache(
-            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
-        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ### ROUTING SETUP ###
         if routing_strategy == "least-busy":
             self.leastbusy_logger = LeastBusyLoggingHandler(
@@ -1248,208 +1252,297 @@ class Router:
             chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
             return chosen_item
 
+    def set_client(self, model: dict):
+        """
+        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        """
+        client_ttl = self.client_ttl
+        litellm_params = model.get("litellm_params", {})
+        model_name = litellm_params.get("model")
+        model_id = model["model_info"]["id"]
+        #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
+        custom_llm_provider = litellm_params.get("custom_llm_provider")
+        custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
+        default_api_base = None
+        default_api_key = None
+        if custom_llm_provider in litellm.openai_compatible_providers:
+            _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
+                model=model_name
+            )
+            default_api_base = api_base
+            default_api_key = api_key
+        if (
+            model_name in litellm.open_ai_chat_completion_models
+            or custom_llm_provider in litellm.openai_compatible_providers
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "openai"
+            or "ft:gpt-3.5-turbo" in model_name
+            or model_name in litellm.open_ai_embedding_models
+        ):
+            # glorified / complicated reading of configs
+            # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
+            # we do this here because we init clients for Azure, OpenAI and we need to set the right key
+            api_key = litellm_params.get("api_key") or default_api_key
+            if api_key and api_key.startswith("os.environ/"):
+                api_key_env_name = api_key.replace("os.environ/", "")
+                api_key = litellm.get_secret(api_key_env_name)
+                litellm_params["api_key"] = api_key
+
+            api_base = litellm_params.get("api_base")
+            base_url = litellm_params.get("base_url")
+            api_base = (
+                api_base or base_url or default_api_base
+            )  # allow users to pass in `api_base` or `base_url` for azure
+            if api_base and api_base.startswith("os.environ/"):
+                api_base_env_name = api_base.replace("os.environ/", "")
+                api_base = litellm.get_secret(api_base_env_name)
+                litellm_params["api_base"] = api_base
+
+            api_version = litellm_params.get("api_version")
+            if api_version and api_version.startswith("os.environ/"):
+                api_version_env_name = api_version.replace("os.environ/", "")
+                api_version = litellm.get_secret(api_version_env_name)
+                litellm_params["api_version"] = api_version
+
+            timeout = litellm_params.pop("timeout", None)
+            if isinstance(timeout, str) and timeout.startswith("os.environ/"):
+                timeout_env_name = timeout.replace("os.environ/", "")
+                timeout = litellm.get_secret(timeout_env_name)
+                litellm_params["timeout"] = timeout
+
+            stream_timeout = litellm_params.pop(
+                "stream_timeout", timeout
+            )  # if no stream_timeout is set, default to timeout
+            if isinstance(stream_timeout, str) and stream_timeout.startswith(
+                "os.environ/"
+            ):
+                stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
+                stream_timeout = litellm.get_secret(stream_timeout_env_name)
+                litellm_params["stream_timeout"] = stream_timeout
+
+            max_retries = litellm_params.pop("max_retries", 2)
+            if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
+                max_retries_env_name = max_retries.replace("os.environ/", "")
+                max_retries = litellm.get_secret(max_retries_env_name)
+                litellm_params["max_retries"] = max_retries
+
+            if "azure" in model_name:
+                if api_base is None:
+                    raise ValueError(
+                        f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
+                    )
+                if api_version is None:
+                    api_version = "2023-07-01-preview"
+                if "gateway.ai.cloudflare.com" in api_base:
+                    if not api_base.endswith("/"):
+                        api_base += "/"
+                    azure_model = model_name.replace("azure/", "")
+                    api_base += f"{azure_model}"
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    # streaming clients can have diff timeouts
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    cache_key = f"{model_id}_stream_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                else:
+                    self.print_verbose(
+                        f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
+                    )
+
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.AsyncClient(
+                            transport=AsyncCustomHTTPTransport(),
+                        ),  # type: ignore
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    cache_key = f"{model_id}_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.Client(
+                            transport=CustomHTTPTransport(),
+                        ),  # type: ignore
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    # streaming clients should have diff timeouts
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    cache_key = f"{model_id}_stream_client"
+                    _client = openai.AzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                    )
+                    self.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+            else:
+                self.print_verbose(
+                    f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
+                )
+                cache_key = f"{model_id}_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                cache_key = f"{model_id}_client"
+                _client = openai.OpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                # streaming clients should have diff timeouts
+                cache_key = f"{model_id}_stream_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=stream_timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                # streaming clients should have diff timeouts
+                cache_key = f"{model_id}_stream_client"
+                _client = openai.OpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=stream_timeout,
+                    max_retries=max_retries,
+                )
+                self.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
     def set_model_list(self, model_list: list):
         self.model_list = copy.deepcopy(model_list)
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         import os
 
         for model in self.model_list:
-            litellm_params = model.get("litellm_params", {})
-            model_name = litellm_params.get("model")
             #### MODEL ID INIT ########
             model_info = model.get("model_info", {})
             model_info["id"] = model_info.get("id", str(uuid.uuid4()))
             model["model_info"] = model_info
-            #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
-            custom_llm_provider = litellm_params.get("custom_llm_provider")
-            custom_llm_provider = (
-                custom_llm_provider or model_name.split("/", 1)[0] or ""
-            )
-            default_api_base = None
-            default_api_key = None
-            if custom_llm_provider in litellm.openai_compatible_providers:
-                _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
-                    model=model_name
-                )
-                default_api_base = api_base
-                default_api_key = api_key
-            if (
-                model_name in litellm.open_ai_chat_completion_models
-                or custom_llm_provider in litellm.openai_compatible_providers
-                or custom_llm_provider == "azure"
-                or custom_llm_provider == "custom_openai"
-                or custom_llm_provider == "openai"
-                or "ft:gpt-3.5-turbo" in model_name
-                or model_name in litellm.open_ai_embedding_models
-            ):
-                # glorified / complicated reading of configs
-                # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
-                # we do this here because we init clients for Azure, OpenAI and we need to set the right key
-                api_key = litellm_params.get("api_key") or default_api_key
-                if api_key and api_key.startswith("os.environ/"):
-                    api_key_env_name = api_key.replace("os.environ/", "")
-                    api_key = litellm.get_secret(api_key_env_name)
-                    litellm_params["api_key"] = api_key
-
-                api_base = litellm_params.get("api_base")
-                base_url = litellm_params.get("base_url")
-                api_base = (
-                    api_base or base_url or default_api_base
-                )  # allow users to pass in `api_base` or `base_url` for azure
-                if api_base and api_base.startswith("os.environ/"):
-                    api_base_env_name = api_base.replace("os.environ/", "")
-                    api_base = litellm.get_secret(api_base_env_name)
-                    litellm_params["api_base"] = api_base
-
-                api_version = litellm_params.get("api_version")
-                if api_version and api_version.startswith("os.environ/"):
-                    api_version_env_name = api_version.replace("os.environ/", "")
-                    api_version = litellm.get_secret(api_version_env_name)
-                    litellm_params["api_version"] = api_version
-
-                timeout = litellm_params.pop("timeout", None)
-                if isinstance(timeout, str) and timeout.startswith("os.environ/"):
-                    timeout_env_name = timeout.replace("os.environ/", "")
-                    timeout = litellm.get_secret(timeout_env_name)
-                    litellm_params["timeout"] = timeout
-
-                stream_timeout = litellm_params.pop(
-                    "stream_timeout", timeout
-                )  # if no stream_timeout is set, default to timeout
-                if isinstance(stream_timeout, str) and stream_timeout.startswith(
-                    "os.environ/"
-                ):
-                    stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
-                    stream_timeout = litellm.get_secret(stream_timeout_env_name)
-                    litellm_params["stream_timeout"] = stream_timeout
-
-                max_retries = litellm_params.pop("max_retries", 2)
-                if isinstance(max_retries, str) and max_retries.startswith(
-                    "os.environ/"
-                ):
-                    max_retries_env_name = max_retries.replace("os.environ/", "")
-                    max_retries = litellm.get_secret(max_retries_env_name)
-                    litellm_params["max_retries"] = max_retries
-
-                if "azure" in model_name:
-                    if api_base is None:
-                        raise ValueError(
-                            f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
-                        )
-                    if api_version is None:
-                        api_version = "2023-07-01-preview"
-                    if "gateway.ai.cloudflare.com" in api_base:
-                        if not api_base.endswith("/"):
-                            api_base += "/"
-                        azure_model = model_name.replace("azure/", "")
-                        api_base += f"{azure_model}"
-                        model["async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                        )
-                        model["client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                        )
-
-                        # streaming clients can have diff timeouts
-                        model["stream_async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-                        model["stream_client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            base_url=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-                    else:
-                        self.print_verbose(
-                            f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
-                        )
-                        model["async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                            http_client=httpx.AsyncClient(
-                                transport=AsyncCustomHTTPTransport(),
-                            ),  # type: ignore
-                        )
-                        model["client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=timeout,
-                            max_retries=max_retries,
-                            http_client=httpx.Client(
-                                transport=CustomHTTPTransport(),
-                            ),  # type: ignore
-                        )
-                        # streaming clients should have diff timeouts
-                        model["stream_async_client"] = openai.AsyncAzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-
-                        model["stream_client"] = openai.AzureOpenAI(
-                            api_key=api_key,
-                            azure_endpoint=api_base,
-                            api_version=api_version,
-                            timeout=stream_timeout,
-                            max_retries=max_retries,
-                        )
-
-                else:
-                    self.print_verbose(
-                        f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
-                    )
-                    model["async_client"] = openai.AsyncOpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=timeout,
-                        max_retries=max_retries,
-                    )
-                    model["client"] = openai.OpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=timeout,
-                        max_retries=max_retries,
-                    )
-
-                    # streaming clients should have diff timeouts
-                    model["stream_async_client"] = openai.AsyncOpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=stream_timeout,
-                        max_retries=max_retries,
-                    )
-
-                    # streaming clients should have diff timeouts
-                    model["stream_client"] = openai.OpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                        timeout=stream_timeout,
-                        max_retries=max_retries,
-                    )
-
-            ############ End of initializing Clients for OpenAI/Azure ###################
+            #### DEPLOYMENT NAMES INIT ########
             self.deployment_names.append(model["litellm_params"]["model"])
-
             ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
             # for get_available_deployment, we use the litellm_param["rpm"]
             # in this snippet we also set rpm to be a litellm_param
@@ -1464,6 +1557,8 @@ class Router:
             ):
                 model["litellm_params"]["tpm"] = model.get("tpm")
 
+            self.set_client(model=model)
+
         self.print_verbose(f"\nInitialized Model List {self.model_list}")
 
         self.model_names = [m["model_name"] for m in model_list]
@@ -1482,16 +1577,49 @@ class Router:
         Returns:
             The appropriate client based on the given client_type and kwargs.
         """
+        model_id = deployment["model_info"]["id"]
         if client_type == "async":
             if kwargs.get("stream") == True:
-                return deployment.get("stream_async_client", None)
+                cache_key = f"{model_id}_stream_async_client"
+                client = self.cache.get_cache(key=cache_key, local_only=True)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                return client
             else:
-                return deployment.get("async_client", None)
+                cache_key = f"{model_id}_async_client"
+                client = self.cache.get_cache(key=cache_key, local_only=True)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                return client
        else:
            if kwargs.get("stream") == True:
-                return deployment.get("stream_client", None)
+                cache_key = f"{model_id}_stream_client"
+                client = self.cache.get_cache(key=cache_key)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key)
+                return client
            else:
-                return deployment.get("client", None)
+                cache_key = f"{model_id}_client"
+                client = self.cache.get_cache(key=cache_key)
+                if client is None:
+                    """
+                    Re-initialize the client
+                    """
+                    self.set_client(model=deployment)
+                    client = self.cache.get_cache(key=cache_key)
+                return client
 
     def print_verbose(self, print_statement):
         try:
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 3b8ea7ed4..987e6670f 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -778,7 +778,8 @@ def test_reading_keys_os_environ():
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
         print("passed testing of reading keys from os.environ")
-        async_client: openai.AsyncAzureOpenAI = model["async_client"]  # type: ignore
+        model_id = model["model_info"]["id"]
+        async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_async_client")  # type: ignore
         assert async_client.api_key == os.environ["AZURE_API_KEY"]
         assert async_client.base_url == os.environ["AZURE_API_BASE"]
         assert async_client.max_retries == (
@@ -791,7 +792,7 @@ def test_reading_keys_os_environ():
 
         print("\n Testing async streaming client")
 
-        stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"]  # type: ignore
+        stream_async_client: openai.AsyncAzureOpenAI = router.cache.get_cache(f"{model_id}_stream_async_client")  # type: ignore
         assert stream_async_client.api_key == os.environ["AZURE_API_KEY"]
         assert stream_async_client.base_url == os.environ["AZURE_API_BASE"]
         assert stream_async_client.max_retries == (
@@ -803,7 +804,7 @@ def test_reading_keys_os_environ():
         print("async stream client set correctly!")
 
         print("\n Testing sync client")
-        client: openai.AzureOpenAI = model["client"]  # type: ignore
+        client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_client")  # type: ignore
         assert client.api_key == os.environ["AZURE_API_KEY"]
         assert client.base_url == os.environ["AZURE_API_BASE"]
         assert client.max_retries == (
@@ -815,7 +816,7 @@ def test_reading_keys_os_environ():
         print("sync client set correctly!")
 
         print("\n Testing sync stream client")
-        stream_client: openai.AzureOpenAI = model["stream_client"]  # type: ignore
+        stream_client: openai.AzureOpenAI = router.cache.get_cache(f"{model_id}_stream_client")  # type: ignore
         assert stream_client.api_key == os.environ["AZURE_API_KEY"]
         assert stream_client.base_url == os.environ["AZURE_API_BASE"]
         assert stream_client.max_retries == (
@@ -877,7 +878,8 @@ def test_reading_openai_keys_os_environ():
             os.environ["AZURE_MAX_RETRIES"]
         ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}"
         print("passed testing of reading keys from os.environ")
-        async_client: openai.AsyncOpenAI = model["async_client"]  # type: ignore
+        model_id = model["model_info"]["id"]
+        async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_async_client")  # type: ignore
         assert async_client.api_key == os.environ["OPENAI_API_KEY"]
         assert async_client.max_retries == (
             os.environ["AZURE_MAX_RETRIES"]
@@ -889,7 +891,7 @@ def test_reading_openai_keys_os_environ():
 
         print("\n Testing async streaming client")
 
-        stream_async_client: openai.AsyncOpenAI = model["stream_async_client"]  # type: ignore
+        stream_async_client: openai.AsyncOpenAI = router.cache.get_cache(key=f"{model_id}_stream_async_client")  # type: ignore
         assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"]
         assert stream_async_client.max_retries == (
             os.environ["AZURE_MAX_RETRIES"]
@@ -900,7 +902,7 @@ def test_reading_openai_keys_os_environ():
         print("async stream client set correctly!")
 
         print("\n Testing sync client")
-        client: openai.AzureOpenAI = model["client"]  # type: ignore
+        client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_client")  # type: ignore
         assert client.api_key == os.environ["OPENAI_API_KEY"]
         assert client.max_retries == (
             os.environ["AZURE_MAX_RETRIES"]
@@ -911,7 +913,7 @@ def test_reading_openai_keys_os_environ():
         print("sync client set correctly!")
 
         print("\n Testing sync stream client")
-        stream_client: openai.AzureOpenAI = model["stream_client"]  # type: ignore
+        stream_client: openai.AzureOpenAI = router.cache.get_cache(key=f"{model_id}_stream_client")  # type: ignore
         assert stream_client.api_key == os.environ["OPENAI_API_KEY"]
         assert stream_client.max_retries == (
             os.environ["AZURE_MAX_RETRIES"]
diff --git a/litellm/tests/test_router_client_init.py b/litellm/tests/test_router_client_init.py
new file mode 100644
index 000000000..79f8ba8b2
--- /dev/null
+++ b/litellm/tests/test_router_client_init.py
@@ -0,0 +1,78 @@
+#### What this tests ####
+# This tests client initialization + reinitialization on the router
+
+#### What this tests ####
+# This tests caching on the router
+import sys, os, time
+import traceback, asyncio
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import Router
+
+
+async def test_router_init():
+    """
+    1. Initializes clients on the router with 0
+    2. Checks if client is still valid
+    3. Checks if new client was initialized
+    """
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo-0613",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "model_info": {"id": "1234"},
+            "tpm": 100000,
+            "rpm": 10000,
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+            },
+            "tpm": 100000,
+            "rpm": 10000,
+        },
+    ]
+
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
+    ]
+    client_ttl_time = 2
+    router = Router(
+        model_list=model_list,
+        redis_host=os.environ["REDIS_HOST"],
+        redis_password=os.environ["REDIS_PASSWORD"],
+        redis_port=os.environ["REDIS_PORT"],
+        cache_responses=True,
+        timeout=30,
+        routing_strategy="simple-shuffle",
+        client_ttl=client_ttl_time,
+    )
+    model = "gpt-3.5-turbo"
+    cache_key = f"1234_async_client"
+    ## ASSERT IT EXISTS AT THE START ##
+    assert router.cache.get_cache(key=cache_key) is not None
+    response1 = await router.acompletion(model=model, messages=messages, temperature=1)
+    await asyncio.sleep(client_ttl_time)
+    ## ASSERT IT'S CLEARED FROM CACHE ##
+    assert router.cache.get_cache(key=cache_key, local_only=True) is None
+    ## ASSERT IT EXISTS AFTER RUNNING __GET_CLIENT() ##
+    assert (
+        router._get_client(
+            deployment=model_list[0], client_type="async", kwargs={"stream": False}
+        )
+        is not None
+    )
+
+
+# asyncio.run(test_router_init())
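
A quick illustration of the `local_only` flag added to `DualCache` above. This is a minimal sketch, not part of the diff; it assumes `DualCache` and `InMemoryCache` are importable from `litellm.caching` and that the in-memory layer honors the `ttl` kwarg the way the `set_client` calls in the diff imply:

```python
from litellm.caching import DualCache, InMemoryCache

# No Redis layer configured - mirrors what the Router builds when no redis_* args are passed.
dual_cache = DualCache(redis_cache=None, in_memory_cache=InMemoryCache())

# local_only=True writes to the in-memory layer only; the Router stores its
# client objects this way (see set_client above), so they never reach Redis.
dual_cache.set_cache(key="1234_async_client", value=object(), ttl=3600, local_only=True)

# local_only=True on reads skips the Redis lookup entirely.
assert dual_cache.get_cache(key="1234_async_client", local_only=True) is not None
```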
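And a usage sketch of the new client-caching flow, loosely following the new test above. The deployment id "my-openai-deployment", the 2-second TTL, and the direct call to the internal `_get_client` helper are illustrative only; it assumes `OPENAI_API_KEY` is set in the environment:

```python
import os
import time

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        # set_client() builds cache keys from this id, e.g. "<id>_async_client"
        "model_info": {"id": "my-openai-deployment"},
    }
]

# Clients are created once in set_model_list() -> set_client() and cached in the
# router's DualCache (local only) for client_ttl seconds.
router = Router(model_list=model_list, client_ttl=2)

deployment = router.model_list[0]
assert router.cache.get_cache(key="my-openai-deployment_async_client", local_only=True) is not None

time.sleep(3)  # let the cached clients expire

# _get_client() finds nothing in the cache and calls set_client() to re-initialize.
client = router._get_client(deployment=deployment, kwargs={"stream": False}, client_type="async")
assert client is not None
```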