feat(router.py): support caching groups

This commit is contained in:
Krrish Dholakia 2023-12-15 21:45:37 -08:00
parent a7822b8772
commit 84ad9f441e
4 changed files with 104 additions and 16 deletions

View file

@ -233,15 +233,29 @@ class Cache:
# check if param == model and model_group is passed in, then override model with model_group # check if param == model and model_group is passed in, then override model with model_group
if param == "model": if param == "model":
model_group = None model_group = None
caching_group = None
metadata = kwargs.get("metadata", None) metadata = kwargs.get("metadata", None)
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
if metadata is not None: if metadata is not None:
model_group = metadata.get("model_group") model_group = metadata.get("model_group")
model_group = metadata.get("model_group", None)
caching_groups = metadata.get("caching_groups", None)
if caching_groups:
for group in caching_groups:
if model_group in group:
caching_group = group
break
if litellm_params is not None: if litellm_params is not None:
metadata = litellm_params.get("metadata", None) metadata = litellm_params.get("metadata", None)
if metadata is not None: if metadata is not None:
model_group = metadata.get("model_group", None) model_group = metadata.get("model_group", None)
param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"] caching_groups = metadata.get("caching_groups", None)
if caching_groups:
for group in caching_groups:
if model_group in group:
caching_group = group
break
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
else: else:
if kwargs[param] is None: if kwargs[param] is None:
continue # ignore None params continue # ignore None params

View file

@ -357,7 +357,7 @@ def completion(
client = kwargs.get("client", None) client = kwargs.get("client", None)
######## end of unpacking kwargs ########### ######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"] openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"] litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response: if mock_response:
@ -1824,7 +1824,7 @@ def embedding(
proxy_server_request = kwargs.get("proxy_server_request", None) proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.get("aembedding", None) aembedding = kwargs.get("aembedding", None)
openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"] openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"] litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider

View file

@ -70,6 +70,7 @@ class Router:
redis_password: Optional[str] = None, redis_password: Optional[str] = None,
cache_responses: Optional[bool] = False, cache_responses: Optional[bool] = False,
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py) cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups
## RELIABILITY ## ## RELIABILITY ##
num_retries: int = 0, num_retries: int = 0,
timeout: Optional[float] = None, timeout: Optional[float] = None,
@ -112,6 +113,7 @@ class Router:
self.default_litellm_params = default_litellm_params self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout) self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0) self.default_litellm_params.setdefault("max_retries", 0)
self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})
### CACHING ### ### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
@ -203,8 +205,10 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {}) kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs) model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs}) return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e: except Exception as e:
@ -241,8 +245,10 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
model_name = data["model"] model_name = data["model"]
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1 self.total_calls[model_name] +=1
@ -269,8 +275,11 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
# call via litellm.completion() # call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e: except Exception as e:
@ -298,8 +307,11 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
########## remove -ModelID-XXXX from model ############## ########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"] original_model_string = data["model"]
# Find the index of "ModelID" in the string # Find the index of "ModelID" in the string
@ -333,8 +345,10 @@ class Router:
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs) model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding() # call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs}) return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@ -350,8 +364,10 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {}) kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")

View file

@ -10,7 +10,8 @@ import litellm
from litellm import Router from litellm import Router
## Scenarios ## Scenarios
## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo", assert cache key is the model group ## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo",
## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_acompletion_caching_on_router(): async def test_acompletion_caching_on_router():
@ -64,6 +65,63 @@ async def test_acompletion_caching_on_router():
end_time = time.time() end_time = time.time()
print(f"timeout error occurred: {end_time - start_time}") print(f"timeout error occurred: {end_time - start_time}")
pass pass
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_caching_on_router_caching_groups():
# tests acompletion + caching on router
try:
litellm.set_verbose = True
model_list = [
{
"model_name": "openai-gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
{
"model_name": "azure-gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
},
"tpm": 100000,
"rpm": 10000,
}
]
messages = [
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
]
start_time = time.time()
router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
cache_responses=True,
timeout=30,
routing_strategy="simple-shuffle",
caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response1: {response1}")
await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response2: {response2}")
assert response1.id == response2.id
assert len(response1.choices[0].message.content) > 0
assert response1.choices[0].message.content == response2.choices[0].message.content
router.reset()
except litellm.Timeout as e:
end_time = time.time()
print(f"timeout error occurred: {end_time - start_time}")
pass
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")