fix(router.py): fix default caching response value

Krrish Dholakia 2023-12-07 13:44:21 -08:00
parent 418099085c
commit e5638e2c5d
7 changed files with 127 additions and 7 deletions
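The substantive change is the `cache_responses` default on `litellm.Router` (plus the matching `caching` check in `utils.py`): it moves from `bool = False` to `Optional[bool] = None`, so an unset flag can defer to a globally configured `litellm.cache` instead of forcing caching off. A minimal sketch of that tri-state pattern; `should_check_cache` is a hypothetical helper, not part of litellm:

```python
# Sketch only: why the default moves from False to None.
# A three-state flag distinguishes "caller said nothing" (None)
# from "caller explicitly opted out" (False).
from typing import Optional


def should_check_cache(caching: Optional[bool], global_cache_configured: bool) -> bool:
    if caching is None:                  # flag left unset by the caller
        return global_cache_configured   # fall back to the global litellm.cache
    return caching                       # an explicit True/False always wins


assert should_check_cache(None, True) is True    # unset flag + global cache -> cached
assert should_check_cache(False, True) is False  # explicit opt-out is respected
assert should_check_cache(True, False) is True   # explicit opt-in requests caching
```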


@@ -68,7 +68,7 @@ You can now test this by starting your proxy:
 litellm --config /path/to/config.yaml
 ```
 
-[Quick Test Proxy](./simple_proxy.md#using-litellm-proxy---curl-request-openai-package)
+[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js)
 
 ## Infisical Secret Manager
 Integrates with [Infisical's Secret Manager](https://infisical.com/) for secure storage and retrieval of API keys and sensitive data.


@@ -11,6 +11,7 @@
 import os
 import inspect
 import redis, litellm
+from typing import List, Optional
 
 def _get_redis_kwargs():
     arg_spec = inspect.getfullargspec(redis.Redis)
@@ -67,6 +68,13 @@ def get_redis_url_from_environment():
     return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
 
 def get_redis_client(**env_overrides):
+    ### check if "os.environ/<key-name>" passed in
+    for k, v in env_overrides.items():
+        if v.startswith("os.environ/"):
+            v = v.replace("os.environ/", "")
+            value = litellm.get_secret(v)
+            env_overrides[k] = value
+
     redis_kwargs = {
         **_redis_kwargs_from_environment(),
         **env_overrides,
@@ -81,5 +89,5 @@ def get_redis_client(**env_overrides):
         return redis.Redis.from_url(**redis_kwargs)
     elif "host" not in redis_kwargs or redis_kwargs['host'] is None:
         raise ValueError("Either 'host' or 'url' must be specified for redis.")
+    litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
     return redis.Redis(**redis_kwargs)
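The new `os.environ/<key-name>` handling lets callers pass secret references instead of literal values: anything prefixed with `os.environ/` is resolved through `litellm.get_secret()` before the Redis client is built. A hedged usage sketch, mirroring the new test added in this commit (the `*_2` variable names are just what that test uses, not a requirement):

```python
import litellm
from litellm.caching import Cache

# Each "os.environ/<key-name>" value is resolved via litellm.get_secret()
# inside get_redis_client(); extra kwargs such as ssl are forwarded to redis.Redis.
litellm.cache = Cache(
    type="redis",
    host="os.environ/REDIS_HOST_2",
    port="os.environ/REDIS_PORT_2",
    password="os.environ/REDIS_PASSWORD_2",
    ssl="os.environ/REDIS_SSL",
)
```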


@@ -416,7 +416,7 @@ def run_ollama_serve():
     """)
 
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key, user_config_file_path, otel_logging, user_custom_auth
+    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path
     config = {}
     try:
         if os.path.exists(config_file_path):
@@ -492,7 +492,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             print(f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}")
             print()
 
-            ## to pass a complete url, just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
+            ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
             litellm.cache = Cache(
                 type=cache_type,
                 host=cache_host,
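The updated comment points at the escape hatch for anything the individual host/port/password fields do not cover: put a full connection URL in the environment and `_redis.py` builds the client with `redis.Redis.from_url()` (the `_redis.py` hunk above prefers a `url` kwarg when one is present). A hedged sketch with a placeholder URL:

```python
import os

# Placeholder URL: "rediss://" enables TLS, plain "redis://" also works.
# Set this in the environment the proxy runs in; os.environ shown for brevity.
os.environ["REDIS_URL"] = "rediss://:my-redis-password@my-redis-host:6380"
```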
@@ -929,6 +929,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         else:
             data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
             data["metadata"]["headers"] = dict(request.headers)
 
+        global user_temperature, user_request_timeout, user_max_tokens, user_api_base
         # override with user settings, these are params passed via cli
         if user_temperature:


@@ -53,7 +53,7 @@ class Router:
     ```
     """
     model_names: List = []
-    cache_responses: bool = False
+    cache_responses: Optional[bool] = None
     default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
     num_retries: int = 0
     tenacity = None
@@ -65,7 +65,7 @@ class Router:
                 redis_host: Optional[str] = None,
                 redis_port: Optional[int] = None,
                 redis_password: Optional[str] = None,
-                cache_responses: bool = False,
+                cache_responses: Optional[bool] = None,
                 cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
                 ## RELIABILITY ##
                 num_retries: int = 0,
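With the new default, a `Router` constructed without `cache_responses` leaves it as `None`, so it no longer forces caching off when a global `litellm.cache` is configured (the scenario the new router caching test below exercises); pass `True` together with the redis arguments only when you explicitly want router-level response caching. A hedged construction sketch reusing the deployment shape from that test:

```python
import os
from litellm import Router

model_list = [{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {  # same deployment shape the new test uses
        "model": "azure/chatgpt-v-2",
        "api_key": os.getenv("AZURE_API_KEY"),
        "api_version": os.getenv("AZURE_API_VERSION"),
        "api_base": os.getenv("AZURE_API_BASE"),
    },
}]

# cache_responses left unset -> stays None -> defers to litellm.cache if one is configured
router = Router(model_list=model_list, num_retries=1)

# explicit opt-in to router-level response caching, backed by Redis
caching_router = Router(
    model_list=model_list,
    redis_host=os.getenv("REDIS_HOST"),
    redis_port=int(os.getenv("REDIS_PORT", "6379")),
    redis_password=os.getenv("REDIS_PASSWORD"),
    cache_responses=True,
)
```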


@@ -0,0 +1,74 @@
#### What this tests ####
# This tests using caching w/ litellm which requires SSL=True

import sys, os
import time
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, Router
from litellm.caching import Cache

messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]


def test_caching_v2():  # test in memory cache
    try:
        litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL")
        response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        litellm.cache = None  # disable cache
        if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
            print(f"response1: {response1}")
            print(f"response2: {response2}")
            raise Exception()
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")

# test_caching_v2()


def test_caching_router():
    """
    Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.
    """
    try:
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",  # openai model name
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            }
        ]
        litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL")
        router = Router(model_list=model_list,
                        routing_strategy="simple-shuffle",
                        set_verbose=False,
                        num_retries=1)  # type: ignore
        response1 = completion(model="gpt-3.5-turbo", messages=messages)
        response2 = completion(model="gpt-3.5-turbo", messages=messages)
        if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
            print(f"response1: {response1}")
            print(f"response2: {response2}")
        litellm.cache = None  # disable cache
        assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content']
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")

test_caching_router()


@@ -0,0 +1,36 @@
#### What this tests ####
# This tests using caching w/ litellm which requires SSL=True

import sys, os
import time
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion
from litellm.caching import Cache

messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]


def test_caching_v2():  # test in memory cache
    try:
        response1 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
        response2 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        litellm.cache = None  # disable cache
        if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
            print(f"response1: {response1}")
            print(f"response2: {response2}")
            raise Exception()
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")

test_caching_v2()


@@ -1509,7 +1509,8 @@ def client(original_function):
             # if caching is false, don't run this
             if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
                 # checking cache
-                if (litellm.cache != None):
+                print_verbose(f"INSIDE CHECKING CACHE")
+                if litellm.cache is not None:
                     print_verbose(f"Checking Cache")
                     cached_result = litellm.cache.get_cache(*args, **kwargs)
                     if cached_result != None:
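Net effect of this check together with the new `None` defaults: an explicit `caching=True` still opts a single call in, and simply configuring `litellm.cache` is now enough for calls that never mention the kwarg, which is what the two new tests cover. A minimal sketch, assuming the default in-memory cache:

```python
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()  # defaults to the local in-memory cache; redis kwargs also work

msgs = [{"role": "user", "content": "hello"}]
r1 = completion(model="gpt-3.5-turbo", messages=msgs)                # no caching kwarg: uses litellm.cache
r2 = completion(model="gpt-3.5-turbo", messages=msgs, caching=True)  # explicit opt-in, as before
litellm.cache = None  # turn caching back off
```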