feat(router.py): support caching groups

This commit is contained in:
Krrish Dholakia 2023-12-15 21:45:37 -08:00
parent a7822b8772
commit 84ad9f441e
4 changed files with 104 additions and 16 deletions

View file

@ -70,6 +70,7 @@ class Router:
redis_password: Optional[str] = None,
cache_responses: Optional[bool] = False,
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups
## RELIABILITY ##
num_retries: int = 0,
timeout: Optional[float] = None,
@ -112,6 +113,7 @@ class Router:
self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})
### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
@ -203,8 +205,10 @@ class Router:
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e:
@ -241,8 +245,10 @@ class Router:
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
@ -269,8 +275,11 @@ class Router:
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
@ -298,8 +307,11 @@ class Router:
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
@ -333,8 +345,10 @@ class Router:
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@ -350,8 +364,10 @@ class Router:
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")