From 2e8d582a3458d38a43e8faf96b0c3953a065d4e7 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 24 Nov 2023 11:38:53 -0800
Subject: [PATCH] fix(proxy_server.py): fix linting issues

---
 litellm/caching.py                     |  6 ++-
 litellm/main.py                        | 10 +++--
 litellm/proxy/proxy_server.py          | 54 +++++++++++++++-----------
 litellm/tests/test_router_cooldowns.py |  4 +-
 litellm/tests/test_router_fallbacks.py |  2 +-
 5 files changed, 44 insertions(+), 32 deletions(-)

diff --git a/litellm/caching.py b/litellm/caching.py
index d14ef8a657..5e8fcf4477 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -156,8 +156,10 @@ class DualCache(BaseCache):
             traceback.print_exc()
 
     def flush_cache(self):
-        self.redis_cache.flush_cache()
-        self.in_memory_cache.flush_cache()
+        if self.in_memory_cache is not None:
+            self.in_memory_cache.flush_cache()
+        if self.redis_cache is not None:
+            self.redis_cache.flush_cache()
 
 #### LiteLLM.Completion Cache ####
 class Cache:
diff --git a/litellm/main.py b/litellm/main.py
index 7ffabd0584..ddc96223a4 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -28,7 +28,9 @@ from litellm.utils import (
     completion_with_fallbacks,
     get_llm_provider,
     get_api_key,
-    mock_completion_streaming_obj
+    mock_completion_streaming_obj,
+    convert_to_model_response_object,
+    token_counter
 )
 from .llms import (
     anthropic,
@@ -2145,7 +2147,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
 
     # # Update usage information if needed
     if messages:
-        response["usage"]["prompt_tokens"] = litellm.utils.token_counter(model=model, messages=messages)
-        response["usage"]["completion_tokens"] = litellm.utils.token_counter(model=model, text=combined_content)
+        response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
+        response["usage"]["completion_tokens"] = token_counter(model=model, text=combined_content)
     response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
-    return litellm.utils.convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
+    return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 1fa1343de9..bb3011b3f3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -277,10 +277,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         if isinstance(litellm.success_callback, list):
             import utils
             print("setting litellm success callback to track cost")
-            if (utils.track_cost_callback) not in litellm.success_callback:
-                litellm.success_callback.append(utils.track_cost_callback)
+            if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(utils.track_cost_callback) # type: ignore
         else:
-            litellm.success_callback = utils.track_cost_callback
+            litellm.success_callback = utils.track_cost_callback # type: ignore
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
@@ -717,32 +717,40 @@ async def test_endpoint(request: Request):
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):
     global celery_fn, llm_model_list
-    body = await request.body()
-    body_str = body.decode()
-    try:
-        data = ast.literal_eval(body_str)
-    except:
-        data = json.loads(body_str)
-    data["model"] = (
-        general_settings.get("completion_model", None) # server default
-        or user_model # model name passed via cli args
-        or data["model"] # default passed in http request
-    )
-    data["llm_model_list"] = llm_model_list
-    print(f"data: {data}")
-    job = celery_fn.apply_async(kwargs=data)
-    return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
-    pass
+    if celery_fn is not None:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        data["model"] = (
+            general_settings.get("completion_model", None) # server default
+            or user_model # model name passed via cli args
+            or data["model"] # default passed in http request
+        )
+        data["llm_model_list"] = llm_model_list
+        print(f"data: {data}")
+        job = celery_fn.apply_async(kwargs=data)
+        return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": "Queue not initialized"},
+        )
 
 @router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_response(request: Request, task_id: str):
     global celery_app_conn, async_result
     try:
-        job = async_result(task_id, app=celery_app_conn)
-        if job.ready():
-            return {"status": "finished", "result": job.result}
+        if celery_app_conn is not None and async_result is not None:
+            job = async_result(task_id, app=celery_app_conn)
+            if job.ready():
+                return {"status": "finished", "result": job.result}
+            else:
+                return {'status': 'queued'}
         else:
-            return {'status': 'queued'}
+            raise Exception()
     except Exception as e:
         return {"status": "finished", "result": str(e)}
 
diff --git a/litellm/tests/test_router_cooldowns.py b/litellm/tests/test_router_cooldowns.py
index 24e47e7cf1..0c7079bd15 100644
--- a/litellm/tests/test_router_cooldowns.py
+++ b/litellm/tests/test_router_cooldowns.py
@@ -38,9 +38,9 @@ model_list = [{ # list of model deployments
 router = Router(model_list=model_list,
                 redis_host=os.getenv("REDIS_HOST"),
                 redis_password=os.getenv("REDIS_PASSWORD"),
-                redis_port=int(os.getenv("REDIS_PORT")),
+                redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
                 routing_strategy="simple-shuffle",
-                set_verbose=True,
+                set_verbose=False,
                 num_retries=1) # type: ignore
 
 kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],}
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index 4f21a0ac1c..cdcc8cc2da 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -101,7 +101,7 @@ def test_async_fallbacks():
 
     asyncio.run(test_get_response())
 
-# test_async_fallbacks()
+test_async_fallbacks()
 
 def test_sync_context_window_fallbacks():
     try: