From 563d59a3059f2c90965749c2fb6a27b3308cc4b0 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Tue, 30 Jul 2024 09:46:30 -0700
Subject: [PATCH] test batches endpoint on proxy

---
 litellm/batches/main.py               |  4 --
 litellm/proxy/proxy_server.py         | 89 ++++++++++++++++++++++++++-
 tests/test_openai_batches_endpoint.py | 23 +++++++
 3 files changed, 110 insertions(+), 6 deletions(-)

diff --git a/litellm/batches/main.py b/litellm/batches/main.py
index 3c41cf5d2..a2ebc664e 100644
--- a/litellm/batches/main.py
+++ b/litellm/batches/main.py
@@ -439,10 +439,6 @@ def list_batches(
     pass
 
 
-async def alist_batch():
-    pass
-
-
 def cancel_batch():
     pass
 
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index fe9b74874..cbb62289b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -4898,12 +4898,12 @@ async def create_batch(
 
 
 @router.get(
-    "/v1/batches{batch_id:path}",
+    "/v1/batches/{batch_id:path}",
     dependencies=[Depends(user_api_key_auth)],
     tags=["batch"],
 )
 @router.get(
-    "/batches{batch_id:path}",
+    "/batches/{batch_id:path}",
     dependencies=[Depends(user_api_key_auth)],
     tags=["batch"],
 )
@@ -4993,6 +4993,91 @@ async def retrieve_batch(
     )
 
 
+@router.get(
+    "/v1/batches",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["batch"],
+)
+@router.get(
+    "/batches",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["batch"],
+)
+async def list_batches(
+    fastapi_response: Response,
+    limit: Optional[int] = None,
+    after: Optional[str] = None,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Lists batches.
+    This is the equivalent of GET https://api.openai.com/v1/batches/
+    Supports identical params as: https://platform.openai.com/docs/api-reference/batch/list
+
+    Example curl:
+    ```
+    curl http://localhost:4000/v1/batches?limit=2 \
+    -H "Authorization: Bearer sk-1234" \
+    -H "Content-Type: application/json"
+
+    ```
+    """
+    global proxy_logging_obj
+    verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
+    try:
+        # for now, use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for alist_batches
+        response = await litellm.alist_batches(
+            custom_llm_provider="openai",
+            after=after,
+            limit=limit,
+        )
+
+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers.update(
+            get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
+
+        return response
+    except Exception as e:
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data={}
+        )
+        verbose_proxy_logger.error(
+            "litellm.proxy.proxy_server.list_batches(): Exception occurred - {}".format(
+                str(e)
+            )
+        )
+        verbose_proxy_logger.debug(traceback.format_exc())
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e.detail)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
+
+
 ######################################################################
 
 #                          END OF  /v1/batches Endpoints Implementation
diff --git a/tests/test_openai_batches_endpoint.py b/tests/test_openai_batches_endpoint.py
index 75e3c3f88..a6e26e782 100644
--- a/tests/test_openai_batches_endpoint.py
+++ b/tests/test_openai_batches_endpoint.py
@@ -41,6 +41,19 @@ async def get_batch_by_id(session, batch_id):
             return None
 
 
+async def list_batches(session):
+    url = f"{BASE_URL}/v1/batches"
+    headers = {"Authorization": f"Bearer {API_KEY}"}
+
+    async with session.get(url, headers=headers) as response:
+        if response.status == 200:
+            result = await response.json()
+            return result
+        else:
+            print(f"Error: Failed to list batches. Status code: {response.status}")
+            return None
+
+
 @pytest.mark.asyncio
 async def test_batches_operations():
     async with aiohttp.ClientSession() as session:
@@ -60,5 +73,15 @@ async def test_batches_operations():
         assert get_batch_response["id"] == batch_id
         assert get_batch_response["input_file_id"] == file_id
 
+        # test LIST Batches
+        list_batch_response = await list_batches(session)
+        print("response from list batches", list_batch_response)
+
+        assert list_batch_response is not None
+        assert len(list_batch_response["data"]) > 0
+
+        element_0 = list_batch_response["data"][0]
+        assert element_0["id"] is not None
+
         # Test delete file
         await delete_file(session, file_id)
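Note (not part of the patch): a minimal sketch of how the new `GET /v1/batches` route could also be exercised with the OpenAI Python SDK rather than raw curl/aiohttp. It assumes a LiteLLM proxy running locally on port 4000 with the virtual key `sk-1234`, matching the curl example in the endpoint docstring above.

```
# Hypothetical client-side check of the new list-batches route.
# Assumes: proxy at http://localhost:4000, virtual key "sk-1234".
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")

# Issues GET /v1/batches?limit=2 against the proxy, which the new
# list_batches() endpoint forwards to litellm.alist_batches(custom_llm_provider="openai", ...)
batches = client.batches.list(limit=2)

for batch in batches.data:
    print(batch.id, batch.status)
```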