mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
test(test_caching.py): cleaning up tests
This commit is contained in:
parent
78582e158a
commit
b0801f61e6
4 changed files with 51 additions and 39 deletions
|
@ -143,7 +143,6 @@ const sidebars = {
|
|||
items: [
|
||||
"caching/local_caching",
|
||||
"caching/redis_cache",
|
||||
"caching/caching_api",
|
||||
],
|
||||
},
|
||||
{
|
||||
|
|
|
@ -133,7 +133,7 @@ experimental = False
|
|||
#### GLOBAL VARIABLES ####
|
||||
llm_router: Optional[litellm.Router] = None
|
||||
llm_model_list: Optional[list] = None
|
||||
server_settings: dict = {}
|
||||
general_settings: dict = {}
|
||||
log_file = "api_log.json"
|
||||
worker_config = None
|
||||
master_key = None
|
||||
|
@ -246,7 +246,7 @@ def run_ollama_serve():
|
|||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||
global master_key
|
||||
config = {}
|
||||
server_settings: dict = {}
|
||||
general_settings: dict = {}
|
||||
try:
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path, 'r') as file:
|
||||
|
@ -293,7 +293,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
if "ollama" in litellm_model_name:
|
||||
run_ollama_serve()
|
||||
|
||||
return router, model_list, server_settings
|
||||
return router, model_list, general_settings
|
||||
|
||||
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
|
||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
||||
|
@ -366,13 +366,13 @@ def initialize(
|
|||
config,
|
||||
use_queue
|
||||
):
|
||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
|
||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
|
||||
generate_feedback_box()
|
||||
user_model = model
|
||||
user_debug = debug
|
||||
dynamic_config = {"general": {}, user_model: {}}
|
||||
if config:
|
||||
llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=config)
|
||||
llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
|
||||
if headers: # model-specific param
|
||||
user_headers = headers
|
||||
dynamic_config[user_model]["headers"] = headers
|
||||
|
@ -480,9 +480,9 @@ async def shutdown_event():
|
|||
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
|
||||
def model_list():
|
||||
global llm_model_list, server_settings
|
||||
global llm_model_list, general_settings
|
||||
all_models = []
|
||||
if server_settings.get("infer_model_from_keys", False):
|
||||
if general_settings.get("infer_model_from_keys", False):
|
||||
all_models = litellm.utils.get_valid_models()
|
||||
if llm_model_list:
|
||||
print(f"llm model list: {llm_model_list}")
|
||||
|
@ -522,7 +522,7 @@ async def completion(request: Request, model: Optional[str] = None):
|
|||
except:
|
||||
data = json.loads(body_str)
|
||||
data["model"] = (
|
||||
server_settings.get("completion_model", None) # server default
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or model # for azure deployments
|
||||
or data["model"] # default passed in http request
|
||||
|
@ -551,7 +551,7 @@ async def completion(request: Request, model: Optional[str] = None):
|
|||
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
|
||||
async def chat_completion(request: Request, model: Optional[str] = None):
|
||||
global server_settings
|
||||
global general_settings
|
||||
try:
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
|
@ -560,7 +560,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
|||
except:
|
||||
data = json.loads(body_str)
|
||||
data["model"] = (
|
||||
server_settings.get("completion_model", None) # server default
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or model # for azure deployments
|
||||
or data["model"] # default passed in http request
|
||||
|
@ -584,8 +584,36 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
|||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
|
||||
async def embeddings(request: Request):
|
||||
pass
|
||||
@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
|
||||
async def embeddings(request: Request, model: str):
|
||||
try:
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
try:
|
||||
data = ast.literal_eval(body_str)
|
||||
except:
|
||||
data = json.loads(body_str)
|
||||
|
||||
data["model"] = (
|
||||
general_settings.get("embedding_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or model # for azure http calls
|
||||
or data["model"] # default passed in http request
|
||||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
if llm_router is not None and data["model"] in router_model_names: # model in router model list
|
||||
response = await llm_router.aembedding(**data)
|
||||
else:
|
||||
response = litellm.aembedding(**data)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
|
||||
async def generate_key_fn(request: Request):
|
||||
|
@ -619,7 +647,7 @@ async def async_queue_request(request: Request):
|
|||
except:
|
||||
data = json.loads(body_str)
|
||||
data["model"] = (
|
||||
server_settings.get("completion_model", None) # server default
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or data["model"] # default passed in http request
|
||||
)
|
||||
|
|
|
@ -91,7 +91,7 @@ def test_embedding_caching():
|
|||
print(f"embedding2: {embedding2}")
|
||||
pytest.fail("Error occurred: Embedding caching failed")
|
||||
|
||||
test_embedding_caching()
|
||||
# test_embedding_caching()
|
||||
|
||||
|
||||
def test_embedding_caching_azure():
|
||||
|
@ -212,25 +212,6 @@ def test_custom_redis_cache_with_key():
|
|||
|
||||
# test_custom_redis_cache_with_key()
|
||||
|
||||
def test_hosted_cache():
|
||||
litellm.cache = Cache(type="hosted") # use api.litellm.ai for caching
|
||||
|
||||
messages = [{"role": "user", "content": "what is litellm arr today?"}]
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
print("response1", response1)
|
||||
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
print("response2", response2)
|
||||
|
||||
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
pytest.fail(f"Hosted cache: Response2 is not cached and the same as response 1")
|
||||
litellm.cache = None
|
||||
|
||||
# test_hosted_cache()
|
||||
|
||||
|
||||
# def test_redis_cache_with_ttl():
|
||||
# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
# sample_model_response_object_str = """{
|
||||
|
|
|
@ -1232,11 +1232,15 @@ def client(original_function):
|
|||
cached_result = litellm.cache.get_cache(*args, **kwargs)
|
||||
if cached_result != None:
|
||||
print_verbose(f"Cache Hit!")
|
||||
call_type = original_function.__name__
|
||||
if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
|
||||
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
|
||||
else:
|
||||
return cached_result
|
||||
if "detail" in cached_result:
|
||||
# implies an error occurred
|
||||
pass
|
||||
else:
|
||||
call_type = original_function.__name__
|
||||
if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
|
||||
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
|
||||
else:
|
||||
return cached_result
|
||||
# MODEL CALL
|
||||
result = original_function(*args, **kwargs)
|
||||
end_time = datetime.datetime.now()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue