test(test_caching.py): cleaning up tests

Krrish Dholakia 2023-11-22 13:43:48 -08:00
parent 78582e158a
commit b0801f61e6
4 changed files with 51 additions and 39 deletions


@@ -143,7 +143,6 @@ const sidebars = {
items: [
"caching/local_caching",
"caching/redis_cache",
"caching/caching_api",
],
},
{


@@ -133,7 +133,7 @@ experimental = False
#### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None
server_settings: dict = {}
general_settings: dict = {}
log_file = "api_log.json"
worker_config = None
master_key = None
@@ -246,7 +246,7 @@ def run_ollama_serve():
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key
config = {}
server_settings: dict = {}
general_settings: dict = {}
try:
if os.path.exists(config_file_path):
with open(config_file_path, 'r') as file:
@@ -293,7 +293,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if "ollama" in litellm_model_name:
run_ollama_serve()
return router, model_list, server_settings
return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
token = f"sk-{secrets.token_urlsafe(16)}"
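For reference, the renamed general_settings dict is what load_router_config() now returns from the proxy config file. A minimal sketch of that lookup, assuming a YAML config with a general_settings block (PyYAML and the key names in the comment are taken from the rest of this diff, not from this hunk):

import yaml

def read_general_settings(config_file_path: str) -> dict:
    # Load the proxy config and pull out the server-level settings the
    # handlers below read, e.g. completion_model, embedding_model,
    # infer_model_from_keys.
    with open(config_file_path, "r") as file:
        config = yaml.safe_load(file) or {}
    return config.get("general_settings", {})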
@@ -366,13 +366,13 @@ def initialize(
config,
use_queue
):
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
generate_feedback_box()
user_model = model
user_debug = debug
dynamic_config = {"general": {}, user_model: {}}
if config:
llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=config)
llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
if headers: # model-specific param
user_headers = headers
dynamic_config[user_model]["headers"] = headers
@@ -480,9 +480,9 @@ async def shutdown_event():
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
def model_list():
global llm_model_list, server_settings
global llm_model_list, general_settings
all_models = []
if server_settings.get("infer_model_from_keys", False):
if general_settings.get("infer_model_from_keys", False):
all_models = litellm.utils.get_valid_models()
if llm_model_list:
print(f"llm model list: {llm_model_list}")
@@ -522,7 +522,7 @@ async def completion(request: Request, model: Optional[str] = None):
except:
data = json.loads(body_str)
data["model"] = (
server_settings.get("completion_model", None) # server default
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or model # for azure deployments
or data["model"] # default passed in http request
@@ -551,7 +551,7 @@ async def completion(request: Request, model: Optional[str] = None):
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None):
global server_settings
global general_settings
try:
body = await request.body()
body_str = body.decode()
@@ -560,7 +560,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
except:
data = json.loads(body_str)
data["model"] = (
server_settings.get("completion_model", None) # server default
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or model # for azure deployments
or data["model"] # default passed in http request
@@ -584,8 +584,36 @@ async def chat_completion(request: Request, model: Optional[str] = None):
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
async def embeddings(request: Request):
pass
@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
async def embeddings(request: Request, model: str):
try:
body = await request.body()
body_str = body.decode()
try:
data = ast.literal_eval(body_str)
except:
data = json.loads(body_str)
data["model"] = (
general_settings.get("embedding_model", None) # server default
or user_model # model name passed via cli args
or model # for azure http calls
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
## ROUTE TO CORRECT ENDPOINT ##
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aembedding(**data)
else:
response = litellm.aembedding(**data)
return response
except Exception as e:
raise e
except Exception as e:
pass
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key_fn(request: Request):
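A hedged example of calling the new azure-compatible embeddings route; the deployment name, address, key, and payload are placeholders rather than values from the commit:

import requests

response = requests.post(
    "http://0.0.0.0:8000/openai/deployments/my-embedding-deployment/embeddings",
    headers={"Authorization": "Bearer sk-1234"},  # hypothetical master key
    json={"model": "text-embedding-ada-002", "input": ["hello world"]},
)
print(response.json())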
@@ -619,7 +647,7 @@ async def async_queue_request(request: Request):
except:
data = json.loads(body_str)
data["model"] = (
server_settings.get("completion_model", None) # server default
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)


@@ -91,7 +91,7 @@ def test_embedding_caching():
print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed")
test_embedding_caching()
# test_embedding_caching()
def test_embedding_caching_azure():
@@ -212,25 +212,6 @@ def test_custom_redis_cache_with_key():
# test_custom_redis_cache_with_key()
def test_hosted_cache():
litellm.cache = Cache(type="hosted") # use api.litellm.ai for caching
messages = [{"role": "user", "content": "what is litellm arr today?"}]
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print("response1", response1)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print("response2", response2)
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Hosted cache: Response2 is not cached and the same as response 1")
litellm.cache = None
# test_hosted_cache()
# def test_redis_cache_with_ttl():
# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
# sample_model_response_object_str = """{
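For context, a sketch of the caching-test pattern this file keeps using after the hosted-cache test above was removed, here with the default in-process cache; the import path, model, and prompt are assumptions rather than lines from this diff, and it expects a valid OPENAI_API_KEY so the first call succeeds:

import pytest
import litellm
from litellm import completion
from litellm.caching import Cache  # assumed import path

def test_local_cache_pattern():
    litellm.cache = Cache()  # defaults to local, in-process caching
    messages = [{"role": "user", "content": "what is litellm arr today?"}]
    response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
    response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
    if response1["choices"][0]["message"]["content"] != response2["choices"][0]["message"]["content"]:
        pytest.fail("second response should have been served from the cache")
    litellm.cache = None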


@@ -1232,11 +1232,15 @@ def client(original_function):
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None:
print_verbose(f"Cache Hit!")
call_type = original_function.__name__
if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else:
return cached_result
if "detail" in cached_result:
# implies an error occurred
pass
else:
call_type = original_function.__name__
if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else:
return cached_result
# MODEL CALL
result = original_function(*args, **kwargs)
end_time = datetime.datetime.now()
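The hunk above changes how a cache hit is handled inside the client() wrapper: if the cached payload contains a "detail" key it is treated as a stored error response and ignored, so execution falls through to a real model call. A standalone sketch of that guard:

def should_return_cached(cached_result) -> bool:
    # Mirrors the check added above: skip cached error payloads.
    if cached_result is None:
        return False
    if "detail" in cached_result:  # an error response was cached
        return False
    return True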