test(test_caching.py): cleaning up tests

Krrish Dholakia 2023-11-22 13:43:48 -08:00
parent e495a8a9c2
commit bd87e30058
4 changed files with 51 additions and 39 deletions


@@ -133,7 +133,7 @@ experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
 llm_model_list: Optional[list] = None
-server_settings: dict = {}
+general_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
 master_key = None
@@ -246,7 +246,7 @@ def run_ollama_serve():
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     global master_key
     config = {}
-    server_settings: dict = {}
+    general_settings: dict = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
@@ -293,7 +293,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             if "ollama" in litellm_model_name:
                 run_ollama_serve()
-    return router, model_list, server_settings
+    return router, model_list, general_settings

 async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
     token = f"sk-{secrets.token_urlsafe(16)}"
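For orientation, load_router_config pulls general_settings (alongside the model list) out of the proxy config file. Below is a minimal sketch of such a config, assuming a YAML layout; the model names and the litellm_params entry are placeholders, and only the general_settings keys used elsewhere in this diff (completion_model, embedding_model, infer_model_from_keys) are taken from the change itself.

# Illustrative sketch only, not taken from this commit; assumes the proxy
# config is YAML, and the model names below are placeholders.
import yaml

EXAMPLE_CONFIG = """
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
general_settings:
  completion_model: gpt-3.5-turbo          # server-wide default for completion routes
  embedding_model: text-embedding-ada-002  # server-wide default for the embeddings routes
  infer_model_from_keys: true              # consumed by /models below
"""

config = yaml.safe_load(EXAMPLE_CONFIG)
general_settings = config.get("general_settings", {})
print(general_settings.get("completion_model"))  # gpt-3.5-turbo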
@@ -366,13 +366,13 @@ def initialize(
     config,
     use_queue
 ):
-    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
+    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
     generate_feedback_box()
     user_model = model
     user_debug = debug
     dynamic_config = {"general": {}, user_model: {}}
     if config:
-        llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=config)
+        llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
     if headers: # model-specific param
         user_headers = headers
         dynamic_config[user_model]["headers"] = headers
@@ -480,9 +480,9 @@ async def shutdown_event():
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
 @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
 def model_list():
-    global llm_model_list, server_settings
+    global llm_model_list, general_settings
     all_models = []
-    if server_settings.get("infer_model_from_keys", False):
+    if general_settings.get("infer_model_from_keys", False):
         all_models = litellm.utils.get_valid_models()
     if llm_model_list:
         print(f"llm model list: {llm_model_list}")
@@ -522,7 +522,7 @@ async def completion(request: Request, model: Optional[str] = None):
         except:
             data = json.loads(body_str)
         data["model"] = (
-            server_settings.get("completion_model", None) # server default
+            general_settings.get("completion_model", None) # server default
             or user_model # model name passed via cli args
             or model # for azure deployments
             or data["model"] # default passed in http request
@@ -551,7 +551,7 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
-    global server_settings
+    global general_settings
     try:
         body = await request.body()
         body_str = body.decode()
@@ -560,7 +560,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         except:
             data = json.loads(body_str)
         data["model"] = (
-            server_settings.get("completion_model", None) # server default
+            general_settings.get("completion_model", None) # server default
             or user_model # model name passed via cli args
             or model # for azure deployments
             or data["model"] # default passed in http request
@@ -584,8 +584,36 @@ async def chat_completion(request: Request, model: Optional[str] = None):
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
-    pass
+@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
+async def embeddings(request: Request, model: str):
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        data["model"] = (
+            general_settings.get("embedding_model", None) # server default
+            or user_model # model name passed via cli args
+            or model # for azure http calls
+            or data["model"] # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+        ## ROUTE TO CORRECT ENDPOINT ##
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.aembedding(**data)
+        else:
+            response = litellm.aembedding(**data)
+        return response
+    except Exception as e:
+        raise e
+    except Exception as e:
+        pass

 @router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
 async def generate_key_fn(request: Request):
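A quick client-side sketch against the embeddings routes added above; the base URL, port, and key are placeholders (the key only needs to pass user_api_key_auth). Note the precedence in the handler: an embedding_model set in general_settings, or a cli-level user_model, overrides the model sent in the request body.

# Hypothetical usage sketch; base URL and API key are placeholders.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/embeddings",  # also exposed as /v1/embeddings and
                                       # /openai/deployments/<deployment>/embeddings
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "text-embedding-ada-002",  # overridden by general_settings["embedding_model"]
                                            # or user_model when either is set
        "input": ["hello world"],
    },
)
print(resp.json())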
@@ -619,7 +647,7 @@ async def async_queue_request(request: Request):
     except:
         data = json.loads(body_str)
     data["model"] = (
-        server_settings.get("completion_model", None) # server default
+        general_settings.get("completion_model", None) # server default
         or user_model # model name passed via cli args
         or data["model"] # default passed in http request
     )
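The same or-chain for picking the model appears in completion, chat_completion, embeddings, and (minus the path parameter) async_queue_request. Restated on its own, with illustrative values:

# Standalone restatement of the precedence used by the handlers above;
# the model names in the asserts are illustrative only.
def resolve_model(general_settings: dict, user_model, path_model, request_body: dict):
    return (
        general_settings.get("completion_model", None) # server default
        or user_model # model name passed via cli args
        or path_model # azure deployment path parameter
        or request_body["model"] # model in the http request body
    )

assert resolve_model({}, None, None, {"model": "gpt-4"}) == "gpt-4"
assert resolve_model({"completion_model": "gpt-3.5-turbo"}, None, None, {"model": "gpt-4"}) == "gpt-3.5-turbo"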