Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00.
feat(router.py): support caching groups

commit 84ad9f441e (parent a7822b8772)

4 changed files with 104 additions and 16 deletions
@@ -70,6 +70,7 @@ class Router:

```python
redis_password: Optional[str] = None,
cache_responses: Optional[bool] = False,
cache_kwargs: dict = {},  # additional kwargs to pass to RedisCache (see caching.py)
caching_groups: Optional[List[tuple]] = None,  # if you want to cache across model groups
## RELIABILITY ##
num_retries: int = 0,
timeout: Optional[float] = None,
```
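Taken together with `cache_responses`, the new `caching_groups` argument lets deployments in different model groups share one response cache. A minimal usage sketch (the model list and group names below are illustrative, not part of this diff):

```python
from litellm import Router

# Hypothetical deployments: two model groups that should share cached responses.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",        # illustrative group name
        "litellm_params": {"model": "gpt-3.5-turbo"},
    },
    {
        "model_name": "azure-gpt-3.5-turbo",  # illustrative group name
        "litellm_params": {"model": "azure/chatgpt-v-2"},
    },
]

router = Router(
    model_list=model_list,
    cache_responses=True,
    # a response cached for one group can be reused for the other
    caching_groups=[("gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
)
```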
@@ -112,6 +113,7 @@ class Router:

```python
self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})

### CACHING ###
cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
```
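The `setdefault(...).update(...)` idiom merges the router-level `caching_groups` into whatever metadata the caller already placed in `default_litellm_params`, rather than overwriting it. A standalone illustration on a plain dict:

```python
# setdefault returns the existing "metadata" dict if present, otherwise
# inserts and returns {}; update() then adds the caching_groups key
# without clobbering metadata the caller already set.
default_litellm_params = {"metadata": {"team": "search"}}
caching_groups = [("gpt-3.5-turbo", "azure-gpt-3.5-turbo")]

default_litellm_params.setdefault("metadata", {}).update(
    {"caching_groups": caching_groups}
)
assert default_litellm_params["metadata"] == {
    "team": "search",
    "caching_groups": [("gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
}
```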
@@ -203,8 +205,10 @@ class Router:

```python
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
    if k not in data:  # prioritize model-specific params > default router params
        data[k] = v
    if k not in kwargs:  # prioritize model-specific params > default router params
        kwargs[k] = v
    elif k == "metadata":
        kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e:
```
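The same merge loop now appears in each call path below. A standalone sketch of the precedence rule it implements (the function name is ours, not litellm's): deployment-specific params beat router defaults, and `metadata` dicts are merged rather than replaced, which is how the router-level `caching_groups` entry reaches each downstream `litellm.completion` call.

```python
def merge_defaults(defaults: dict, data: dict, kwargs: dict) -> None:
    for k, v in defaults.items():
        if k not in data:      # model-specific params win over router defaults
            data[k] = v
        if k not in kwargs:    # same rule for per-call kwargs
            kwargs[k] = v
        elif k == "metadata":  # special case: merge, don't overwrite
            kwargs[k].update(v)

defaults = {"timeout": 30, "metadata": {"caching_groups": [("a", "b")]}}
data = {"timeout": 600}                   # deployment param wins
kwargs = {"metadata": {"trace_id": "123"}}
merge_defaults(defaults, data, kwargs)
assert data["timeout"] == 600
assert kwargs["metadata"] == {"trace_id": "123", "caching_groups": [("a", "b")]}
```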
@@ -241,8 +245,10 @@ class Router:

```python
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
    if k not in data:  # prioritize model-specific params > default router params
        data[k] = v
    if k not in kwargs:  # prioritize model-specific params > default router params
        kwargs[k] = v
    elif k == "metadata":
        kwargs[k].update(v)

model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] += 1
```
@@ -269,8 +275,11 @@ class Router:

```python
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
    if k not in data:  # prioritize model-specific params > default router params
        data[k] = v
    if k not in kwargs:  # prioritize model-specific params > default router params
        kwargs[k] = v
    elif k == "metadata":
        kwargs[k].update(v)

# call via litellm.text_completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})  # type: ignore
except Exception as e:
```
@@ -298,8 +307,11 @@ class Router:

```python
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
    if k not in data:  # prioritize model-specific params > default router params
        data[k] = v
    if k not in kwargs:  # prioritize model-specific params > default router params
        kwargs[k] = v
    elif k == "metadata":
        kwargs[k].update(v)

########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
```
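The hunk is truncated here, but its comments describe locating the `ModelID` marker and stripping the router's internal `-ModelID-XXXX` suffix before the model name reaches the provider. A hedged sketch of that kind of stripping (the helper is hypothetical, not from this diff):

```python
def strip_model_id(model: str) -> str:
    # Hypothetical helper: drop everything from "-ModelID-" onward, if present.
    idx = model.find("-ModelID-")
    return model[:idx] if idx != -1 else model

assert strip_model_id("gpt-3.5-turbo-ModelID-1234") == "gpt-3.5-turbo"
assert strip_model_id("gpt-3.5-turbo") == "gpt-3.5-turbo"
```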
@@ -333,8 +345,10 @@ class Router:

```python
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]})  # [TODO]: move to using async_function_with_fallbacks
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
    if k not in data:  # prioritize model-specific params > default router params
        data[k] = v
    if k not in kwargs:  # prioritize model-specific params > default router params
        kwargs[k] = v
    elif k == "metadata":
        kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
```
@@ -350,8 +364,10 @@ class Router:

```python
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
    if k not in data:  # prioritize model-specific params > default router params
        data[k] = v
    if k not in kwargs:  # prioritize model-specific params > default router params
        kwargs[k] = v
    elif k == "metadata":
        kwargs[k].update(v)

model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
```