diff --git a/litellm/caching.py b/litellm/caching.py
index 02dd1aebe..73dde7cf9 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -233,15 +233,29 @@ class Cache:
             # check if param == model and model_group is passed in, then override model with model_group
             if param == "model":
                 model_group = None
+                caching_group = None
                 metadata = kwargs.get("metadata", None)
                 litellm_params = kwargs.get("litellm_params", {})
                 if metadata is not None:
                     model_group = metadata.get("model_group")
+                    model_group = metadata.get("model_group", None)
+                    caching_groups = metadata.get("caching_groups", None)
+                    if caching_groups:
+                        for group in caching_groups:
+                            if model_group in group:
+                                caching_group = group
+                                break
                 if litellm_params is not None:
                     metadata = litellm_params.get("metadata", None)
                     if metadata is not None:
                         model_group = metadata.get("model_group", None)
-                param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"]
+                        caching_groups = metadata.get("caching_groups", None)
+                        if caching_groups:
+                            for group in caching_groups:
+                                if model_group in group:
+                                    caching_group = group
+                                    break
+                param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
             else:
                 if kwargs[param] is None:
                     continue # ignore None params
diff --git a/litellm/main.py b/litellm/main.py
index 0ad091bf7..d128fc14b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -357,7 +357,7 @@ def completion(
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -1824,7 +1824,7 @@ def embedding(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
     openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
-    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
+    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
diff --git a/litellm/router.py b/litellm/router.py
index 994e52b52..96d90dde0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -70,6 +70,7 @@ class Router:
                  redis_password: Optional[str] = None,
                  cache_responses: Optional[bool] = False,
                  cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
+                 caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups
                  ## RELIABILITY ##
                  num_retries: int = 0,
                  timeout: Optional[float] = None,
@@ -112,6 +113,7 @@ class Router:
         self.default_litellm_params = default_litellm_params
         self.default_litellm_params.setdefault("timeout", timeout)
         self.default_litellm_params.setdefault("max_retries", 0)
+        self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})
 
         ### CACHING ###
         cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
@@ -203,8 +205,10 @@ class Router:
             data = deployment["litellm_params"].copy()
             kwargs["model_info"] = deployment.get("model_info", {})
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
             model_client = self._get_client(deployment=deployment, kwargs=kwargs)
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
         except Exception as e:
@@ -241,8 +245,10 @@ class Router:
             data = deployment["litellm_params"].copy()
             model_name = data["model"]
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
             model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
             self.total_calls[model_name] +=1
@@ -269,8 +275,11 @@ class Router:
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
             # call via litellm.completion()
             return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
         except Exception as e:
@@ -298,8 +307,11 @@ class Router:
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
             ########## remove -ModelID-XXXX from model ##############
             original_model_string = data["model"]
             # Find the index of "ModelID" in the string
@@ -333,8 +345,10 @@ class Router:
             kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
             model_client = self._get_client(deployment=deployment, kwargs=kwargs)
             # call via litellm.embedding()
             return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@@ -350,8 +364,10 @@ class Router:
             data = deployment["litellm_params"].copy()
             kwargs["model_info"] = deployment.get("model_info", {})
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
             model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
diff --git a/litellm/tests/test_router_caching.py b/litellm/tests/test_router_caching.py
index 0ca85d2ef..27191c8d2 100644
--- a/litellm/tests/test_router_caching.py
+++ b/litellm/tests/test_router_caching.py
@@ -10,7 +10,8 @@ import litellm
 from litellm import Router
 
 ## Scenarios
-## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo", assert cache key is the model group
+## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo",
+## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group
 
 @pytest.mark.asyncio
 async def test_acompletion_caching_on_router():
@@ -64,6 +65,63 @@ async def test_acompletion_caching_on_router():
             end_time = time.time()
             print(f"timeout error occurred: {end_time - start_time}")
             pass
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
+
+@pytest.mark.asyncio
+async def test_acompletion_caching_on_router_caching_groups():
+    # tests acompletion + caching on router
+    try:
+        litellm.set_verbose = True
+        model_list = [
+            {
+                "model_name": "openai-gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            },
+            {
+                "model_name": "azure-gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_version": os.getenv("AZURE_API_VERSION")
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            }
+        ]
+
+        messages = [
+            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
+        ]
+        start_time = time.time()
+        router = Router(model_list=model_list,
+                        redis_host=os.environ["REDIS_HOST"],
+                        redis_password=os.environ["REDIS_PASSWORD"],
+                        redis_port=os.environ["REDIS_PORT"],
+                        cache_responses=True,
+                        timeout=30,
+                        routing_strategy="simple-shuffle",
+                        caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
+        response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
+        print(f"response1: {response1}")
+        await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
+        response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
+        print(f"response2: {response2}")
+        assert response1.id == response2.id
+        assert len(response1.choices[0].message.content) > 0
+        assert response1.choices[0].message.content == response2.choices[0].message.content
+        router.reset()
+    except litellm.Timeout as e:
+        end_time = time.time()
+        print(f"timeout error occurred: {end_time - start_time}")
+        pass
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
\ No newline at end of file
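
Usage note (not part of the patch): a minimal sketch of how the new `caching_groups` option is meant to be wired up, based on the test above. The model names, deployment params, and Redis environment variables mirror the test and are placeholders for your own setup; responses for deployments listed in the same tuple are expected to share one cache entry.

```python
import os
from litellm import Router

# two separate model groups that should share a cache entry
model_list = [
    {
        "model_name": "openai-gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo-0613",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
    {
        "model_name": "azure-gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
        },
    },
]

router = Router(
    model_list=model_list,
    redis_host=os.environ["REDIS_HOST"],
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=os.environ["REDIS_PORT"],
    cache_responses=True,
    # model groups in the same tuple form one caching group, so the cache key
    # is derived from the group rather than the individual model group
    caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
)

messages = [{"role": "user", "content": "write a one sentence poem"}]
r1 = router.completion(model="openai-gpt-3.5-turbo", messages=messages)
# second call hits the other model group but should be served from the shared cache entry
r2 = router.completion(model="azure-gpt-3.5-turbo", messages=messages)
```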