From 14115d0d607ea1b6503f5579fb9abccc06d5f59b Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 21 Dec 2023 15:39:09 +0530
Subject: [PATCH] feat(proxy_server.py): add new images/generation endpoint

---
 litellm/proxy/proxy_server.py                 | 66 +++++++++++++++++++
 .../test_configs/test_config_no_auth.yaml     | 17 ++++-
 litellm/tests/test_proxy_server.py            | 25 ++++++-
 litellm/tests/test_router.py                  | 12 ++--
 4 files changed, 112 insertions(+), 8 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 296d8e576..8a32f1b4f 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1127,6 +1127,72 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
             detail=error_msg
         )
 
+
+@router.post("/v1/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
+@router.post("/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
+async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+    global proxy_logging_obj
+    try:
+        # Use orjson to parse JSON data, orjson speeds up requests significantly
+        body = await request.body()
+        data = orjson.loads(body)
+
+        # Include original request and headers in the data
+        data["proxy_server_request"] = {
+            "url": str(request.url),
+            "method": request.method,
+            "headers": dict(request.headers),
+            "body": copy.copy(data)  # use copy instead of deepcopy
+        }
+
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
+        data["model"] = (
+            general_settings.get("image_generation_model", None)  # server default
+            or user_model  # model name passed via cli args
+            or data["model"]  # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+        if "metadata" in data:
+            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+            data["metadata"]["headers"] = dict(request.headers)
+        else:
+            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
+            data["metadata"]["headers"] = dict(request.headers)
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation")
+        ## ROUTE TO CORRECT ENDPOINT ##
+        if llm_router is not None and data["model"] in router_model_names:  # model in router model list
+            response = await llm_router.aimage_generation(**data)
+        elif llm_router is not None and data["model"] in llm_router.deployment_names:  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.aimage_generation(**data, specific_deployment=True)
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias:  # model set in model_group_alias
+            response = await llm_router.aimage_generation(**data)  # ensure this goes through the llm_router; router will do the correct alias mapping
+        else:
+            response = await litellm.aimage_generation(**data)
+        background_tasks.add_task(log_input_output, request, response)  # background task for logging to OTEL
+
+        return response
+    except Exception as e:
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
+        traceback.print_exc()
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            try:
+                status = e.status_code  # type: ignore
+            except:
+                status = 500
+            raise HTTPException(
+                status_code=status,
+                detail=error_msg
+            )
 
 #### KEY MANAGEMENT ####
 
 @router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
diff --git a/litellm/tests/test_configs/test_config_no_auth.yaml b/litellm/tests/test_configs/test_config_no_auth.yaml
index 2fd9ef203..76e7a294b 100644
--- a/litellm/tests/test_configs/test_config_no_auth.yaml
+++ b/litellm/tests/test_configs/test_config_no_auth.yaml
@@ -78,4 +78,19 @@ model_list:
     model: "bedrock/amazon.titan-embed-text-v1"
 - model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
   litellm_params:
-    model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
\ No newline at end of file
+    model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
+- model_name: dall-e-3
+  litellm_params:
+    model: dall-e-3
+- model_name: dall-e-3
+  litellm_params:
+    model: "azure/dall-e-3-test"
+    api_version: "2023-12-01-preview"
+    api_base: "os.environ/AZURE_SWEDEN_API_BASE"
+    api_key: "os.environ/AZURE_SWEDEN_API_KEY"
+- model_name: dall-e-2
+  litellm_params:
+    model: "azure/"
+    api_version: "2023-06-01-preview"
+    api_base: "os.environ/AZURE_API_BASE"
+    api_key: "os.environ/AZURE_API_KEY"
\ No newline at end of file
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index 5e9854f43..b71f5b890 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -101,7 +101,7 @@ def test_chat_completion_azure(client_no_auth):
 
 # Run the test
 # test_chat_completion_azure()
-
+### EMBEDDING
 def test_embedding(client_no_auth):
     global headers
     from litellm.proxy.proxy_server import user_custom_auth
@@ -161,7 +161,30 @@ def test_sagemaker_embedding(client_no_auth):
 
 # Run the test
 # test_embedding()
+#### IMAGE GENERATION
+
+def test_img_gen(client_no_auth):
+    global headers
+    from litellm.proxy.proxy_server import user_custom_auth
+    try:
+        test_data = {
+            "model": "dall-e-3",
+            "prompt": "A cute baby sea otter",
+            "n": 1,
+            "size": "1024x1024"
+        }
+
+        response = client_no_auth.post("/v1/images/generations", json=test_data)
+
+        assert response.status_code == 200
+        result = response.json()
+        print(len(result["data"][0]["url"]))
+        assert len(result["data"][0]["url"]) > 10
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+#### ADDITIONAL
 # @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
 def test_add_new_model(client_no_auth):
     global headers
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index b52db394f..81440c257 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -455,12 +455,12 @@ async def test_aimg_gen_on_router():
         }
     ]
     router = Router(model_list=model_list)
-    # response = await router.aimage_generation(
-    #     model="dall-e-3",
-    #     prompt="A cute baby sea otter"
-    # )
-    # print(response)
-    # assert len(response.data) > 0
+    response = await router.aimage_generation(
+        model="dall-e-3",
+        prompt="A cute baby sea otter"
+    )
+    print(response)
+    assert len(response.data) > 0
     response = await router.aimage_generation(
         model="dall-e-2",
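---

Testing note: a minimal sketch of exercising the new route from a client, assuming the proxy has been started locally (e.g. "litellm --config test_config_no_auth.yaml") and is listening on http://0.0.0.0:8000 with auth disabled. The host, port, and model name here are illustrative assumptions, not part of this patch.

    # Hypothetical smoke test against a locally running proxy.
    # Assumes: proxy on http://0.0.0.0:8000, no auth, a "dall-e-3" model_name in the config.
    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/v1/images/generations",
        json={
            "model": "dall-e-3",  # must match a model_name in the proxy config
            "prompt": "A cute baby sea otter",
            "n": 1,
            "size": "1024x1024",
        },
    )
    resp.raise_for_status()
    # OpenAI-compatible image response: data[0]["url"] points at the generated image
    print(resp.json()["data"][0]["url"])

Both /images/generations and /v1/images/generations hit the same handler, so either path works with OpenAI-style clients.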