test(test_caching.py): cleaning up tests

This commit is contained in:
Krrish Dholakia 2023-11-22 13:43:48 -08:00
parent 78582e158a
commit b0801f61e6
4 changed files with 51 additions and 39 deletions

View file

@ -143,7 +143,6 @@ const sidebars = {
items: [ items: [
"caching/local_caching", "caching/local_caching",
"caching/redis_cache", "caching/redis_cache",
"caching/caching_api",
], ],
}, },
{ {

View file

@ -133,7 +133,7 @@ experimental = False
#### GLOBAL VARIABLES #### #### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None llm_model_list: Optional[list] = None
server_settings: dict = {} general_settings: dict = {}
log_file = "api_log.json" log_file = "api_log.json"
worker_config = None worker_config = None
master_key = None master_key = None
@ -246,7 +246,7 @@ def run_ollama_serve():
def load_router_config(router: Optional[litellm.Router], config_file_path: str): def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key global master_key
config = {} config = {}
server_settings: dict = {} general_settings: dict = {}
try: try:
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path, 'r') as file: with open(config_file_path, 'r') as file:
@ -293,7 +293,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if "ollama" in litellm_model_name: if "ollama" in litellm_model_name:
run_ollama_serve() run_ollama_serve()
return router, model_list, server_settings return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict): async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
token = f"sk-{secrets.token_urlsafe(16)}" token = f"sk-{secrets.token_urlsafe(16)}"
@ -366,13 +366,13 @@ def initialize(
config, config,
use_queue use_queue
): ):
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
generate_feedback_box() generate_feedback_box()
user_model = model user_model = model
user_debug = debug user_debug = debug
dynamic_config = {"general": {}, user_model: {}} dynamic_config = {"general": {}, user_model: {}}
if config: if config:
llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=config) llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
if headers: # model-specific param if headers: # model-specific param
user_headers = headers user_headers = headers
dynamic_config[user_model]["headers"] = headers dynamic_config[user_model]["headers"] = headers
@ -480,9 +480,9 @@ async def shutdown_event():
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)]) @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
def model_list(): def model_list():
global llm_model_list, server_settings global llm_model_list, general_settings
all_models = [] all_models = []
if server_settings.get("infer_model_from_keys", False): if general_settings.get("infer_model_from_keys", False):
all_models = litellm.utils.get_valid_models() all_models = litellm.utils.get_valid_models()
if llm_model_list: if llm_model_list:
print(f"llm model list: {llm_model_list}") print(f"llm model list: {llm_model_list}")
@ -522,7 +522,7 @@ async def completion(request: Request, model: Optional[str] = None):
except: except:
data = json.loads(body_str) data = json.loads(body_str)
data["model"] = ( data["model"] = (
server_settings.get("completion_model", None) # server default general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args or user_model # model name passed via cli args
or model # for azure deployments or model # for azure deployments
or data["model"] # default passed in http request or data["model"] # default passed in http request
@ -551,7 +551,7 @@ async def completion(request: Request, model: Optional[str] = None):
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None): async def chat_completion(request: Request, model: Optional[str] = None):
global server_settings global general_settings
try: try:
body = await request.body() body = await request.body()
body_str = body.decode() body_str = body.decode()
@ -560,7 +560,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
except: except:
data = json.loads(body_str) data = json.loads(body_str)
data["model"] = ( data["model"] = (
server_settings.get("completion_model", None) # server default general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args or user_model # model name passed via cli args
or model # for azure deployments or model # for azure deployments
or data["model"] # default passed in http request or data["model"] # default passed in http request
@ -584,8 +584,36 @@ async def chat_completion(request: Request, model: Optional[str] = None):
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)]) @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)]) @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
async def embeddings(request: Request): @router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
pass async def embeddings(request: Request, model: str):
try:
body = await request.body()
body_str = body.decode()
try:
data = ast.literal_eval(body_str)
except:
data = json.loads(body_str)
data["model"] = (
general_settings.get("embedding_model", None) # server default
or user_model # model name passed via cli args
or model # for azure http calls
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
## ROUTE TO CORRECT ENDPOINT ##
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aembedding(**data)
else:
response = litellm.aembedding(**data)
return response
except Exception as e:
raise e
except Exception as e:
pass
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)]) @router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key_fn(request: Request): async def generate_key_fn(request: Request):
@ -619,7 +647,7 @@ async def async_queue_request(request: Request):
except: except:
data = json.loads(body_str) data = json.loads(body_str)
data["model"] = ( data["model"] = (
server_settings.get("completion_model", None) # server default general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args or user_model # model name passed via cli args
or data["model"] # default passed in http request or data["model"] # default passed in http request
) )

View file

@ -91,7 +91,7 @@ def test_embedding_caching():
print(f"embedding2: {embedding2}") print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed") pytest.fail("Error occurred: Embedding caching failed")
test_embedding_caching() # test_embedding_caching()
def test_embedding_caching_azure(): def test_embedding_caching_azure():
@ -212,25 +212,6 @@ def test_custom_redis_cache_with_key():
# test_custom_redis_cache_with_key() # test_custom_redis_cache_with_key()
def test_hosted_cache():
litellm.cache = Cache(type="hosted") # use api.litellm.ai for caching
messages = [{"role": "user", "content": "what is litellm arr today?"}]
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print("response1", response1)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print("response2", response2)
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Hosted cache: Response2 is not cached and the same as response 1")
litellm.cache = None
# test_hosted_cache()
# def test_redis_cache_with_ttl(): # def test_redis_cache_with_ttl():
# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) # cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
# sample_model_response_object_str = """{ # sample_model_response_object_str = """{

View file

@ -1232,11 +1232,15 @@ def client(original_function):
cached_result = litellm.cache.get_cache(*args, **kwargs) cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None: if cached_result != None:
print_verbose(f"Cache Hit!") print_verbose(f"Cache Hit!")
call_type = original_function.__name__ if "detail" in cached_result:
if call_type == CallTypes.completion.value and isinstance(cached_result, dict): # implies an error occurred
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse()) pass
else: else:
return cached_result call_type = original_function.__name__
if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else:
return cached_result
# MODEL CALL # MODEL CALL
result = original_function(*args, **kwargs) result = original_function(*args, **kwargs)
end_time = datetime.datetime.now() end_time = datetime.datetime.now()