Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

commit b0801f61e6 (parent 78582e158a)
test(test_caching.py): cleaning up tests

4 changed files with 51 additions and 39 deletions
@@ -143,7 +143,6 @@ const sidebars = {
       items: [
         "caching/local_caching",
         "caching/redis_cache",
-        "caching/caching_api",
       ],
     },
     {
@@ -133,7 +133,7 @@ experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
 llm_model_list: Optional[list] = None
-server_settings: dict = {}
+general_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
 master_key = None
@@ -246,7 +246,7 @@ def run_ollama_serve():
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     global master_key
     config = {}
-    server_settings: dict = {}
+    general_settings: dict = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
@@ -293,7 +293,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         if "ollama" in litellm_model_name:
             run_ollama_serve()

-    return router, model_list, server_settings
+    return router, model_list, general_settings

async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
    token = f"sk-{secrets.token_urlsafe(16)}"
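The hunks above rename the proxy-wide settings dict from `server_settings` to `general_settings`, both as a module-level global and as the value returned by `load_router_config`. A minimal, hedged sketch of how that block might be pulled out of a config file under the new name; the YAML layout, the `yaml` dependency, and the `load_general_settings` helper are illustrative assumptions, not the proxy's actual loader:

# Hedged sketch (illustrative, not the real load_router_config): pull the
# renamed `general_settings` block out of a YAML config file.
import yaml  # PyYAML, assumed here to be the config parser

def load_general_settings(config_file_path: str) -> dict:
    """Hypothetical helper returning only the general_settings block."""
    with open(config_file_path, "r") as f:
        config = yaml.safe_load(f) or {}
    # keys such as "completion_model", "embedding_model", or
    # "infer_model_from_keys" (all used later in this diff) would live here
    return config.get("general_settings", {})

# Example: a file containing
#   general_settings:
#     completion_model: gpt-3.5-turbo
# yields {"completion_model": "gpt-3.5-turbo"}.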
@@ -366,13 +366,13 @@ def initialize(
    config,
    use_queue
):
-    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
+    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
    generate_feedback_box()
    user_model = model
    user_debug = debug
    dynamic_config = {"general": {}, user_model: {}}
    if config:
-        llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=config)
+        llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
    if headers: # model-specific param
        user_headers = headers
        dynamic_config[user_model]["headers"] = headers
@@ -480,9 +480,9 @@ async def shutdown_event():
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
 @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
 def model_list():
-    global llm_model_list, server_settings
+    global llm_model_list, general_settings
     all_models = []
-    if server_settings.get("infer_model_from_keys", False):
+    if general_settings.get("infer_model_from_keys", False):
         all_models = litellm.utils.get_valid_models()
     if llm_model_list:
         print(f"llm model list: {llm_model_list}")
@@ -522,7 +522,7 @@ async def completion(request: Request, model: Optional[str] = None):
    except:
        data = json.loads(body_str)
    data["model"] = (
-        server_settings.get("completion_model", None) # server default
+        general_settings.get("completion_model", None) # server default
        or user_model # model name passed via cli args
        or model # for azure deployments
        or data["model"] # default passed in http request
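The `data["model"]` expression above resolves the model by precedence: a server-wide default from `general_settings`, then the model passed on the CLI, then the Azure-style path parameter, and finally whatever the request body carried. A tiny self-contained illustration of that `or`-chain with made-up values:

# Illustration of the precedence in the or-chain above; every value here
# is made up and only serves to show which source wins.
general_settings = {}                 # no "completion_model" configured
user_model = None                     # nothing passed via CLI args
model = "azure-gpt-deployment"        # path parameter, e.g. /openai/deployments/...
data = {"model": "gpt-3.5-turbo"}     # body sent by the client

data["model"] = (
    general_settings.get("completion_model", None)  # server default
    or user_model                                   # CLI model
    or model                                        # azure deployment path
    or data["model"]                                # request body
)
print(data["model"])  # -> "azure-gpt-deployment"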
@@ -551,7 +551,7 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
-    global server_settings
+    global general_settings
     try:
         body = await request.body()
         body_str = body.decode()
@@ -560,7 +560,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
    except:
        data = json.loads(body_str)
    data["model"] = (
-        server_settings.get("completion_model", None) # server default
+        general_settings.get("completion_model", None) # server default
        or user_model # model name passed via cli args
        or model # for azure deployments
        or data["model"] # default passed in http request
@@ -584,8 +584,36 @@ async def chat_completion(request: Request, model: Optional[str] = None):

 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
-    pass
+@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
+async def embeddings(request: Request, model: str):
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+
+        data["model"] = (
+            general_settings.get("embedding_model", None) # server default
+            or user_model # model name passed via cli args
+            or model # for azure http calls
+            or data["model"] # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+
+        ## ROUTE TO CORRECT ENDPOINT ##
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.aembedding(**data)
+        else:
+            response = litellm.aembedding(**data)
+        return response
+    except Exception as e:
+        raise e
+    except Exception as e:
+        pass

 @router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
 async def generate_key_fn(request: Request):
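With the hunk above, `/embeddings` goes from a stub to a working route, and an Azure-compatible `/openai/deployments/{model}/embeddings` path is added. A hedged usage sketch against a locally running proxy; the address, port, API key, deployment name, model, and input are placeholders:

# Hedged usage sketch for the new azure-compatible embeddings route.
# Everything below (host, port, key, deployment, model, input) is a placeholder.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/openai/deployments/my-embedding-deployment/embeddings",
    headers={"Authorization": "Bearer sk-1234"},  # checked by user_api_key_auth
    json={"model": "text-embedding-ada-002", "input": ["hello world"]},
)
print(resp.status_code, resp.json())

Inside the proxy, that deployment name feeds the same fallback chain shown for completions, with `general_settings.get("embedding_model")` as the server default.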
@@ -619,7 +647,7 @@ async def async_queue_request(request: Request):
    except:
        data = json.loads(body_str)
    data["model"] = (
-        server_settings.get("completion_model", None) # server default
+        general_settings.get("completion_model", None) # server default
        or user_model # model name passed via cli args
        or data["model"] # default passed in http request
    )
@@ -91,7 +91,7 @@ def test_embedding_caching():
         print(f"embedding2: {embedding2}")
         pytest.fail("Error occurred: Embedding caching failed")

-test_embedding_caching()
+# test_embedding_caching()


 def test_embedding_caching_azure():
@@ -212,25 +212,6 @@ def test_custom_redis_cache_with_key():

 # test_custom_redis_cache_with_key()

-def test_hosted_cache():
-    litellm.cache = Cache(type="hosted") # use api.litellm.ai for caching
-
-    messages = [{"role": "user", "content": "what is litellm arr today?"}]
-    response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
-    print("response1", response1)
-
-    response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
-    print("response2", response2)
-
-    if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
-        print(f"response1: {response1}")
-        print(f"response2: {response2}")
-        pytest.fail(f"Hosted cache: Response2 is not cached and the same as response 1")
-    litellm.cache = None
-
-# test_hosted_cache()
-
-
 # def test_redis_cache_with_ttl():
 #     cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
 #     sample_model_response_object_str = """{
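The deleted `test_hosted_cache` exercised the same hit/miss pattern as the tests that remain in this file. For reference, a minimal sketch of that pattern pointed at the local in-memory cache instead of the retired hosted one; it assumes a valid `OPENAI_API_KEY` and uses an illustrative prompt:

# Minimal sketch of the caching-test pattern used in this file, pointed at
# the local in-memory cache. Assumes OPENAI_API_KEY is set; prompt is illustrative.
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()  # local in-memory cache
messages = [{"role": "user", "content": "what time is it?"}]

response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# the second call should be served from cache, so the contents match
assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
litellm.cache = None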
@@ -1232,11 +1232,15 @@ def client(original_function):
             cached_result = litellm.cache.get_cache(*args, **kwargs)
             if cached_result != None:
                 print_verbose(f"Cache Hit!")
-                call_type = original_function.__name__
-                if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
-                    return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
-                else:
-                    return cached_result
+                if "detail" in cached_result:
+                    # implies an error occurred
+                    pass
+                else:
+                    call_type = original_function.__name__
+                    if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
+                        return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                    else:
+                        return cached_result
             # MODEL CALL
             result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
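The change above guards the cache-hit path: if the cached payload looks like an error body (a dict carrying a "detail" key, the shape FastAPI error responses use), it is ignored and execution falls through to a fresh model call. A small illustrative sketch of that guard with hypothetical payloads and helper names:

# Illustrative sketch of the new cache-hit guard; `call_model` and the
# payloads are hypothetical stand-ins, not litellm internals.
def serve_from_cache_or_call(cached_result, call_model):
    if cached_result is not None and "detail" not in cached_result:
        return cached_result   # genuine cache hit
    return call_model()        # cache miss, or a cached error body

assert serve_from_cache_or_call({"detail": "rate limited"}, lambda: "fresh") == "fresh"
assert serve_from_cache_or_call({"choices": [{"message": {"content": "hi"}}]},
                                lambda: "fresh") != "fresh"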