feat(router.py): support caching groups

commit 84ad9f441e (parent a7822b8772)

4 changed files with 104 additions and 16 deletions
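The feature in one picture: a `caching_groups` tuple tells the Router that two model groups should share one cache entry, so a response cached for either deployment can serve the other. The sketch below is a minimal usage example assembled from the constructor change and the new test in this commit; the model names, group names, and environment variables are the ones the test uses, not requirements.

import os
from litellm import Router

# Minimal usage sketch based on the test added in this commit.
model_list = [
    {"model_name": "openai-gpt-3.5-turbo",
     "litellm_params": {"model": "gpt-3.5-turbo-0613", "api_key": os.getenv("OPENAI_API_KEY")}},
    {"model_name": "azure-gpt-3.5-turbo",
     "litellm_params": {"model": "azure/chatgpt-v-2", "api_key": os.getenv("AZURE_API_KEY"),
                        "api_base": os.getenv("AZURE_API_BASE"), "api_version": os.getenv("AZURE_API_VERSION")}},
]

router = Router(
    model_list=model_list,
    redis_host=os.environ["REDIS_HOST"],          # cache_responses=True stores responses in Redis
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=os.environ["REDIS_PORT"],
    cache_responses=True,
    # the new parameter: these two model groups share one cache key space
    caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
)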
@@ -233,15 +233,29 @@ class Cache:
                 # check if param == model and model_group is passed in, then override model with model_group
                 if param == "model":
                     model_group = None
+                    caching_group = None
                     metadata = kwargs.get("metadata", None)
                     litellm_params = kwargs.get("litellm_params", {})
                     if metadata is not None:
                         model_group = metadata.get("model_group")
+                        model_group = metadata.get("model_group", None)
+                        caching_groups = metadata.get("caching_groups", None)
+                        if caching_groups:
+                            for group in caching_groups:
+                                if model_group in group:
+                                    caching_group = group
+                                    break
                     if litellm_params is not None:
                         metadata = litellm_params.get("metadata", None)
                         if metadata is not None:
                             model_group = metadata.get("model_group", None)
-                    param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"]
+                            caching_groups = metadata.get("caching_groups", None)
+                            if caching_groups:
+                                for group in caching_groups:
+                                    if model_group in group:
+                                        caching_group = group
+                                        break
+                    param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
                 else:
                     if kwargs[param] is None:
                         continue # ignore None params
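In plain terms, the hunk above changes what the cache key is built from: if the request metadata carries a caching_groups list and the request's model_group belongs to one of those groups, the whole group tuple is used in place of the individual model group. A standalone sketch of that resolution order follows; the helper name is hypothetical and not part of litellm's API.

from typing import Optional, Tuple, Union

# Illustrative helper (hypothetical name) mirroring the resolution order introduced above:
# caching_group > model_group > raw "model" value.
def resolve_cache_model_value(model: str, metadata: Optional[dict]) -> Union[str, Tuple[str, ...]]:
    metadata = metadata or {}
    model_group = metadata.get("model_group", None)
    caching_group = None
    for group in metadata.get("caching_groups", None) or []:
        if model_group in group:  # e.g. "azure-gpt-3.5-turbo" in ("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")
            caching_group = group
            break
    return caching_group or model_group or model

# Both deployments of one caching group now contribute the same value to the cache key:
groups = [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")]
assert resolve_cache_model_value("azure/chatgpt-v-2",
                                 {"model_group": "azure-gpt-3.5-turbo", "caching_groups": groups}) == groups[0]
assert resolve_cache_model_value("azure/chatgpt-v-2",
                                 {"model_group": "azure-gpt-3.5-turbo"}) == "azure-gpt-3.5-turbo"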
@@ -357,7 +357,7 @@ def completion(
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -1824,7 +1824,7 @@ def embedding(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
     openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
-    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
+    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
 
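Both hunks above make the same addition for completion() and embedding(): "caching_groups" joins the litellm_params allow-list, so it is recognised as an internal litellm argument and filtered out of non_default_params rather than being forwarded to the provider as a model parameter. A small illustration with made-up kwargs:

# Made-up kwargs for illustration; only the filtering behaviour matters here.
default_params = ["metadata", "caching", "preset_cache_key", "caching_groups"]
kwargs = {
    "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
    "metadata": {"model_group": "openai-gpt-3.5-turbo"},
    "top_k": 40,  # a provider-specific param
}
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
assert non_default_params == {"top_k": 40}  # caching_groups stays internal, top_k is passed through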
@@ -70,6 +70,7 @@ class Router:
                 redis_password: Optional[str] = None,
                 cache_responses: Optional[bool] = False,
                 cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
+                caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups
                 ## RELIABILITY ##
                 num_retries: int = 0,
                 timeout: Optional[float] = None,
@@ -112,6 +113,7 @@ class Router:
         self.default_litellm_params = default_litellm_params
         self.default_litellm_params.setdefault("timeout", timeout)
         self.default_litellm_params.setdefault("max_retries", 0)
+        self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})
 
         ### CACHING ###
         cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
@@ -203,8 +205,10 @@ class Router:
             data = deployment["litellm_params"].copy()
             kwargs["model_info"] = deployment.get("model_info", {})
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
             model_client = self._get_client(deployment=deployment, kwargs=kwargs)
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
         except Exception as e:
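Two router.py changes work together here: the constructor stores caching_groups under default_litellm_params["metadata"], and each call path now merges (rather than skips) the default metadata into the per-call kwargs, since every call already carries its own metadata with model_group. Without the new elif k == "metadata" branch, the k not in kwargs check would silently drop the caching-group info. A small sketch with plain dicts, not Router internals:

# Plain-dict sketch of the merge performed in each Router call path above.
default_litellm_params = {
    "timeout": 30,
    "metadata": {"caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")]},
}
kwargs = {"metadata": {"model_group": "openai-gpt-3.5-turbo"}}  # set per call by the Router

for k, v in default_litellm_params.items():
    if k not in kwargs:          # defaults only fill in params the call did not set
        kwargs[k] = v
    elif k == "metadata":        # ...except metadata, which is merged
        kwargs[k].update(v)

# kwargs["metadata"] now contains both model_group and caching_groups,
# which is exactly what the cache-key logic in the first hunk reads.
assert set(kwargs["metadata"]) == {"model_group", "caching_groups"}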
@@ -241,8 +245,10 @@ class Router:
             data = deployment["litellm_params"].copy()
             model_name = data["model"]
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
 
             model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
             self.total_calls[model_name] +=1
@@ -269,8 +275,11 @@ class Router:
 
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
             # call via litellm.completion()
             return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
         except Exception as e:
@@ -298,8 +307,11 @@ class Router:
 
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
             ########## remove -ModelID-XXXX from model ##############
             original_model_string = data["model"]
             # Find the index of "ModelID" in the string
@@ -333,8 +345,10 @@ class Router:
             kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
             model_client = self._get_client(deployment=deployment, kwargs=kwargs)
             # call via litellm.embedding()
             return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@@ -350,8 +364,10 @@ class Router:
             data = deployment["litellm_params"].copy()
             kwargs["model_info"] = deployment.get("model_info", {})
             for k, v in self.default_litellm_params.items():
-                if k not in data: # prioritize model-specific params > default router params
-                    data[k] = v
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
 
             model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
 
@@ -10,7 +10,8 @@ import litellm
 from litellm import Router
 
 ## Scenarios
-## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo", assert cache key is the model group
+## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo",
+## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group
 
 @pytest.mark.asyncio
 async def test_acompletion_caching_on_router():
@@ -67,3 +68,60 @@ async def test_acompletion_caching_on_router():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
+
+@pytest.mark.asyncio
+async def test_acompletion_caching_on_router_caching_groups():
+    # tests acompletion + caching on router
+    try:
+        litellm.set_verbose = True
+        model_list = [
+            {
+                "model_name": "openai-gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            },
+            {
+                "model_name": "azure-gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_version": os.getenv("AZURE_API_VERSION")
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            }
+        ]
+
+        messages = [
+            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
+        ]
+        start_time = time.time()
+        router = Router(model_list=model_list,
+                        redis_host=os.environ["REDIS_HOST"],
+                        redis_password=os.environ["REDIS_PASSWORD"],
+                        redis_port=os.environ["REDIS_PORT"],
+                        cache_responses=True,
+                        timeout=30,
+                        routing_strategy="simple-shuffle",
+                        caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
+        response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
+        print(f"response1: {response1}")
+        await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
+        response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
+        print(f"response2: {response2}")
+        assert response1.id == response2.id
+        assert len(response1.choices[0].message.content) > 0
+        assert response1.choices[0].message.content == response2.choices[0].message.content
+        router.reset()
+    except litellm.Timeout as e:
+        end_time = time.time()
+        print(f"timeout error occurred: {end_time - start_time}")
+        pass
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")