From b0801f61e6ba1b0f9d3bda35f115c2116ab856e3 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 22 Nov 2023 13:43:48 -0800
Subject: [PATCH] test(test_caching.py): cleaning up tests

---
 docs/my-website/sidebars.js   |  1 -
 litellm/proxy/proxy_server.py | 54 ++++++++++++++++++++++++++---------
 litellm/tests/test_caching.py | 21 +-------------
 litellm/utils.py              | 14 +++++----
 4 files changed, 51 insertions(+), 39 deletions(-)

diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index d50f03675..ddae583c1 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -143,7 +143,6 @@ const sidebars = {
       items: [
         "caching/local_caching",
         "caching/redis_cache",
-        "caching/caching_api",
       ],
     },
     {
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8543a535e..91238a8e4 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -133,7 +133,7 @@ experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
 llm_model_list: Optional[list] = None
-server_settings: dict = {}
+general_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
 master_key = None
@@ -246,7 +246,7 @@ def run_ollama_serve():
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     global master_key
     config = {}
-    server_settings: dict = {}
+    general_settings: dict = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
@@ -293,7 +293,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         if "ollama" in litellm_model_name:
             run_ollama_serve()
-    return router, model_list, server_settings
+    return router, model_list, general_settings


 async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
     token = f"sk-{secrets.token_urlsafe(16)}"
@@ -366,13 +366,13 @@ def initialize(
     config,
     use_queue
 ):
-    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
+    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
     generate_feedback_box()
     user_model = model
     user_debug = debug
     dynamic_config = {"general": {}, user_model: {}}
     if config:
-        llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=config)
+        llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
     if headers:  # model-specific param
         user_headers = headers
         dynamic_config[user_model]["headers"] = headers
@@ -480,9 +480,9 @@ async def shutdown_event():
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
 @router.get("/models", dependencies=[Depends(user_api_key_auth)])  # if project requires model list
 def model_list():
-    global llm_model_list, server_settings
+    global llm_model_list, general_settings
     all_models = []
-    if server_settings.get("infer_model_from_keys", False):
+    if general_settings.get("infer_model_from_keys", False):
         all_models = litellm.utils.get_valid_models()
     if llm_model_list:
         print(f"llm model list: {llm_model_list}")
@@ -522,7 +522,7 @@ async def completion(request: Request, model: Optional[str] = None):
     except:
         data = json.loads(body_str)
     data["model"] = (
-        server_settings.get("completion_model", None) # server default
+        general_settings.get("completion_model", None) # server default
         or user_model # model name passed via cli args
         or model # for azure deployments
         or data["model"] # default passed in http request
@@ -551,7 +551,7 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
-    global server_settings
+    global general_settings
     try:
         body = await request.body()
         body_str = body.decode()
@@ -560,7 +560,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
     except:
         data = json.loads(body_str)
     data["model"] = (
-        server_settings.get("completion_model", None) # server default
+        general_settings.get("completion_model", None) # server default
         or user_model # model name passed via cli args
         or model # for azure deployments
         or data["model"] # default passed in http request
@@ -584,8 +584,36 @@ async def chat_completion(request: Request, model: Optional[str] = None):

 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
-    pass
+@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
+async def embeddings(request: Request, model: str):
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+
+        data["model"] = (
+            general_settings.get("embedding_model", None) # server default
+            or user_model # model name passed via cli args
+            or model # for azure http calls
+            or data["model"] # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+
+        ## ROUTE TO CORRECT ENDPOINT ##
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.aembedding(**data)
+        else:
+            response = litellm.aembedding(**data)
+        return response
+    except Exception as e:
+        raise e
+    except Exception as e:
+        pass

 @router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
 async def generate_key_fn(request: Request):
@@ -619,7 +647,7 @@ async def async_queue_request(request: Request):
     except:
         data = json.loads(body_str)
     data["model"] = (
-        server_settings.get("completion_model", None) # server default
+        general_settings.get("completion_model", None) # server default
         or user_model # model name passed via cli args
         or data["model"] # default passed in http request
     )
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 83188b5ba..a482fb48d 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -91,7 +91,7 @@ def test_embedding_caching():
         print(f"embedding2: {embedding2}")
         pytest.fail("Error occurred: Embedding caching failed")

-test_embedding_caching()
+# test_embedding_caching()


 def test_embedding_caching_azure():
@@ -212,25 +212,6 @@ def test_custom_redis_cache_with_key():

 # test_custom_redis_cache_with_key()

-def test_hosted_cache():
-    litellm.cache = Cache(type="hosted") # use api.litellm.ai for caching
-
-    messages = [{"role": "user", "content": "what is litellm arr today?"}]
-    response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
-    print("response1", response1)
-
-    response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
-    print("response2", response2)
-
-    if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
-        print(f"response1: {response1}")
-        print(f"response2: {response2}")
-        pytest.fail(f"Hosted cache: Response2 is not cached and the same as response 1")
-    litellm.cache = None
-
-# test_hosted_cache()
-
-
 # def test_redis_cache_with_ttl():
 #     cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
 #     sample_model_response_object_str = """{
diff --git a/litellm/utils.py b/litellm/utils.py
index 9af58db33..5d4ef74b2 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1232,11 +1232,15 @@ def client(original_function):
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
                     print_verbose(f"Cache Hit!")
-                    call_type = original_function.__name__
-                    if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
-                        return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
-                    else:
-                        return cached_result
+                    if "detail" in cached_result:
+                        # implies an error occurred
+                        pass
+                    else:
+                        call_type = original_function.__name__
+                        if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
+                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                        else:
+                            return cached_result
             # MODEL CALL
             result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
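Note: the snippet below is a minimal usage sketch, not part of the patch itself. It shows how a client could exercise the azure-compatible embeddings route the patch adds to the proxy. The host/port, deployment name, API key, and input text are placeholder assumptions, not values taken from the patch.

# Hypothetical client call against a locally running litellm proxy.
# Host, port, deployment name ("my-embedding-deployment"), and key are placeholders.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/openai/deployments/my-embedding-deployment/embeddings",
    headers={"Authorization": "Bearer sk-1234"},
    json={"model": "my-embedding-deployment", "input": ["hello world"]},
)
print(resp.json())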