From 523d8e5977b794dae358ed6a6b7938b94c73e2e1 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 6 Jan 2024 14:59:10 +0530
Subject: [PATCH 1/7] build(Dockerfile): moves prisma logic to dockerfile

---
 Dockerfile                    | 48 ++++++++++++-----------------
 litellm/proxy/proxy_server.py | 22 ++++++++++++++++
 retry_push.sh                 | 28 ++++++++++++++++++++
 schema.prisma                 | 33 ++++++++++++++++++++++++
 4 files changed, 99 insertions(+), 32 deletions(-)
 create mode 100644 retry_push.sh
 create mode 100644 schema.prisma

diff --git a/Dockerfile b/Dockerfile
index b76aaf1d1..e46a9d6b8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,53 +1,37 @@
-# Base image for building
-ARG LITELLM_BUILD_IMAGE=python:3.9
 # Runtime image
 ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
-
 # Builder stage
 FROM $LITELLM_BUILD_IMAGE as builder
-# Set the working directory to /app
-WORKDIR /app
-
-# Install build dependencies
-RUN apt-get clean && apt-get update && \
-    apt-get install -y gcc python3-dev && \
-    rm -rf /var/lib/apt/lists/*
-
-RUN pip install --upgrade pip && \
-    pip install build
-
-# Copy the current directory contents into the container at /app
-COPY . .
-
-# Build the package
-RUN rm -rf dist/* && python -m build
-
-# There should be only one wheel file now, assume the build only creates one
-RUN ls -1 dist/*.whl | head -1
-
-# Install the package
-RUN pip install dist/*.whl
-
-# install dependencies as wheels
-RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
+
@@ -35,8 +34,12 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
 
 # Runtime stage
 FROM $LITELLM_RUNTIME_IMAGE as runtime
+ARG with_database
 WORKDIR /app
+# Copy the current directory contents into the container at /app
+COPY . .
+RUN ls -la /app
 # Copy the built wheel from the builder stage to the runtime stage; assumes only one wheel file is present
 COPY --from=builder /app/dist/*.whl .
-COPY --from=builder /wheels/ /wheels/
-
+
@@ -45,9 +48,17 @@ COPY --from=builder /wheels/ /wheels/
 # Install the built wheel using pip; again using a wildcard if it's the only file
 RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
+# Check if the with_database argument is set to 'true'
+RUN echo "Value of with_database is: ${with_database}"
+# If true, execute the following instructions
+RUN if [ "$with_database" = "true" ]; then \
+    prisma generate; \
+    chmod +x /app/retry_push.sh; \
+    /app/retry_push.sh; \
+    fi
 
-EXPOSE 4000/tcp
+EXPOSE 8000/tcp
 
 # Set your entrypoint and command
 ENTRYPOINT ["litellm"]
-CMD ["--port", "4000"]
\ No newline at end of file
+CMD ["--config", "./proxy_server_config.yaml", "--port", "8000", "--num_workers", "8"]
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 56fc298ae..4f6f9caab 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2419,6 +2419,28 @@ async def health_endpoint(
     }
 
 
+@router.get("/health/readiness", tags=["health"])
+async def health_readiness():
+    """
+    Unprotected endpoint for checking if worker can receive requests
+    """
+    global prisma_client
+    if prisma_client is not None:  # if db passed in, check if it's connected
+        if prisma_client.db.is_connected() == True:
+            return {"status": "healthy"}
+    else:
+        return {"status": "healthy"}
+    raise HTTPException(status_code=503, detail="Service Unhealthy")
+
+
+@router.get("/health/liveliness", tags=["health"])
+async def health_liveliness():
+    """
+    Unprotected endpoint for checking if worker is alive
+    """
+    return "I'm alive!"
+
+
 @router.get("/")
 async def home(request: Request):
     return "LiteLLM: RUNNING"
diff --git a/retry_push.sh b/retry_push.sh
new file mode 100644
index 000000000..5c41d72a0
--- /dev/null
+++ b/retry_push.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+retry_count=0
+max_retries=3
+exit_code=1
+
+until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
+do
+    retry_count=$((retry_count+1))
+    echo "Attempt $retry_count..."
+
+    # Run the Prisma db push command
+    prisma db push --accept-data-loss
+
+    exit_code=$?
+
+    if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
+        echo "Retrying in 10 seconds..."
+        sleep 10
+    fi
+done
+
+if [ $exit_code -ne 0 ]; then
+    echo "Unable to push database changes after $max_retries retries."
+    exit 1
+fi
+
+echo "Database push successful!"
\ No newline at end of file
diff --git a/schema.prisma b/schema.prisma
new file mode 100644
index 000000000..d12cac8f2
--- /dev/null
+++ b/schema.prisma
@@ -0,0 +1,33 @@
+datasource client {
+  provider = "postgresql"
+  url      = env("DATABASE_URL")
+}
+
+generator client {
+  provider = "prisma-client-py"
+}
+
+model LiteLLM_UserTable {
+  user_id    String @unique
+  max_budget Float?
+  spend      Float  @default(0.0)
+  user_email String?
+}
+
+// required for token gen
+model LiteLLM_VerificationToken {
+  token                 String @unique
+  spend                 Float  @default(0.0)
+  expires               DateTime?
+  models                String[]
+  aliases               Json   @default("{}")
+  config                Json   @default("{}")
+  user_id               String?
+  max_parallel_requests Int?
+  metadata              Json   @default("{}")
+}
+
+model LiteLLM_Config {
+  param_name  String @id
+  param_value Json?
+}
\ No newline at end of file
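
The two endpoints added to proxy_server.py above are aimed at orchestration probes: /health/liveliness only confirms the worker process is serving requests, while /health/readiness also answers 503 when a Prisma database was configured but is not connected. A minimal sketch of an external probe follows; it is not part of the patch, and the base URL is an assumption for illustration (port 8000 matches the new EXPOSE/CMD).

    # Hedged sketch of an external health probe; BASE_URL is assumed, not from the patch.
    import urllib.error
    import urllib.request

    BASE_URL = "http://localhost:8000"  # assumption: proxy address from the new CMD/EXPOSE

    def probe(path: str) -> bool:
        # 200 -> healthy; 503 or a connection failure -> not alive/ready.
        try:
            with urllib.request.urlopen(f"{BASE_URL}{path}", timeout=5) as resp:
                return resp.status == 200
        except (urllib.error.URLError, OSError):
            return False

    if __name__ == "__main__":
        print("alive:", probe("/health/liveliness"))
        print("ready:", probe("/health/readiness"))
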

From 13e8535b14a04dabf09ad6bb640c633da22eba73 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 6 Jan 2024 15:17:42 +0530
Subject: [PATCH 2/7] test(test_async_fn.py): skip cloudflare test - flaky

---
 litellm/tests/test_async_fn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index 5d6f18836..2d1f83fa1 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -155,6 +155,7 @@ def test_async_completion_cloudflare():
 
 
 # test_async_completion_cloudflare()
+@pytest.mark.skip(reason="Flaky test")
 def test_get_cloudflare_response_streaming():
     import asyncio
 

From 9a4a96f46e9c6f1383dfe4ac24c533302048bc10 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 6 Jan 2024 15:50:05 +0530
Subject: [PATCH 3/7] perf(azure+openai-files): use model_dump instead of json.loads + model_dump_json

---
 .circleci/config.yml   |  3 +--
 litellm/llms/azure.py  | 18 +++++++++---------
 litellm/llms/openai.py | 45 +++++++++++++++++++++++----------------------
 3 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 25451f47b..5afd0c5d1 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -163,5 +163,4 @@ workflows:
       filters:
         branches:
           only:
-            - main
-            - /litellm_.*/
\ No newline at end of file
+            - main
\ No newline at end of file
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 98cc97d53..8a387e8a9 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -248,7 +248,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             response = azure_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
-            stringified_response = response.model_dump_json()
+            stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
                 input=messages,
@@ -261,7 +261,7 @@ class AzureChatCompletion(BaseLLM):
                 },
             )
             return convert_to_model_response_object(
-                response_object=json.loads(stringified_response),
+                response_object=stringified_response,
                 model_response_object=model_response,
             )
         except AzureOpenAIError as e:
@@ -323,7 +323,7 @@ class AzureChatCompletion(BaseLLM):
                 **data, timeout=timeout
             )
             return convert_to_model_response_object(
-                response_object=json.loads(response.model_dump_json()),
+                response_object=response.model_dump(),
                 model_response_object=model_response,
             )
         except AzureOpenAIError as e:
@@ -465,7 +465,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.embeddings.create(**data, timeout=timeout)
-            stringified_response = response.model_dump_json()
+            stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -474,7 +474,7 @@ class AzureChatCompletion(BaseLLM):
                 original_response=stringified_response,
             )
             return convert_to_model_response_object(
-                response_object=json.loads(stringified_response),
+                response_object=stringified_response,
                 model_response_object=model_response,
                 response_type="embedding",
             )
@@ -564,7 +564,7 @@ class AzureChatCompletion(BaseLLM):
                 original_response=response,
             )
 
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")  # type: ignore
+            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
@@ -599,7 +599,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.images.generate(**data, timeout=timeout)
-            stringified_response = response.model_dump_json()
+            stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -608,7 +608,7 @@ class AzureChatCompletion(BaseLLM):
                 original_response=stringified_response,
             )
             return convert_to_model_response_object(
-                response_object=json.loads(stringified_response),
+                response_object=stringified_response,
                 model_response_object=model_response,
                 response_type="image_generation",
             )
@@ -697,7 +697,7 @@ class AzureChatCompletion(BaseLLM):
                 original_response=response,
             )
             # return response
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation")  # type: ignore
+            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 0299c502c..91a79fa57 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -280,18 +280,6 @@ class OpenAIChatCompletion(BaseLLM):
                     max_retries=max_retries,
                 )
             else:
-                ## LOGGING
-                logging_obj.pre_call(
-                    input=messages,
-                    api_key=api_key,
-                    additional_args={
-                        "headers": headers,
-                        "api_base": api_base,
-                        "acompletion": acompletion,
-                        "complete_input_dict": data,
-                    },
-                )
-
                 if not isinstance(max_retries, int):
                     raise OpenAIError(
                         status_code=422, message="max retries must be an int"
@@ -306,8 +294,21 @@ class OpenAIChatCompletion(BaseLLM):
                     )
                 else:
                     openai_client = client
+
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=messages,
+                    api_key=openai_client.api_key,
+                    additional_args={
+                        "headers": headers,
+                        "api_base": openai_client._base_url._uri_reference,
+                        "acompletion": acompletion,
+                        "complete_input_dict": data,
+                    },
+                )
+
                 response = openai_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
-                stringified_response = response.model_dump_json()
+                stringified_response = response.model_dump()
                 logging_obj.post_call(
                     input=messages,
                     api_key=api_key,
@@ -315,7 +316,7 @@ class OpenAIChatCompletion(BaseLLM):
                     additional_args={"complete_input_dict": data},
                 )
                 return convert_to_model_response_object(
-                    response_object=json.loads(stringified_response),
+                    response_object=stringified_response,
                     model_response_object=model_response,
                 )
         except Exception as e:
@@ -386,7 +387,7 @@ class OpenAIChatCompletion(BaseLLM):
             response = await openai_aclient.chat.completions.create(
                 **data, timeout=timeout
             )
-            stringified_response = response.model_dump_json()
+            stringified_response = response.model_dump()
             logging_obj.post_call(
                 input=data["messages"],
                 api_key=api_key,
@@ -394,7 +395,7 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
             )
             return convert_to_model_response_object(
-                response_object=json.loads(stringified_response),
+                response_object=stringified_response,
                 model_response_object=model_response,
             )
         except Exception as e:
@@ -527,7 +528,7 @@ class OpenAIChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
-            stringified_response = response.model_dump_json()
+            stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -535,7 +536,7 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")  # type: ignore
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="embedding")  # type: ignore
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
@@ -597,7 +598,7 @@ class OpenAIChatCompletion(BaseLLM):
                 original_response=response,
             )
 
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")  # type: ignore
+            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
             raise e
@@ -634,7 +635,7 @@ class OpenAIChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
-            stringified_response = response.model_dump_json()
+            stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
                 input=prompt,
@@ -642,7 +643,7 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")  # type: ignore
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
@@ -710,7 +711,7 @@ class OpenAIChatCompletion(BaseLLM):
                 original_response=response,
             )
             # return response
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation")  # type: ignore
+            return convert_to_model_response_object(response_object=model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
             raise e

From 712f89b4f129cc5941cb826366fbddf7deb0477d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 6 Jan 2024 17:02:50 +0530
Subject: [PATCH 4/7] fix(utils.py): handle original_response being a json

---
 litellm/utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/litellm/utils.py b/litellm/utils.py
index 8f93fb620..b258b473a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -956,6 +956,8 @@ class Logging:
     ):  # Log the exact result from the LLM API, for streaming - log the type of response received
         litellm.error_logs["POST_CALL"] = locals()
+        if isinstance(original_response, dict):
+            original_response = json.dumps(original_response)
         try:
             self.model_call_details["input"] = input
             self.model_call_details["api_key"] = api_key
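
PATCH 3 above replaces json.loads(response.model_dump_json()) with response.model_dump() throughout the Azure and OpenAI handlers, and PATCH 4 then re-serializes the now-dict original_response before it is logged. The snippet below illustrates the difference on a stand-in pydantic v2 model (the classes are invented for illustration, not OpenAI SDK types): model_dump() yields the dict in one traversal, where the old path rendered a JSON string only to parse it straight back.

    # Illustrative sketch with made-up models; assumes pydantic v2.
    import json
    from pydantic import BaseModel

    class Choice(BaseModel):
        index: int
        text: str

    class FakeResponse(BaseModel):
        id: str
        choices: list[Choice]

    resp = FakeResponse(id="resp-123", choices=[Choice(index=0, text="hi")])

    old_way = json.loads(resp.model_dump_json())  # dict -> JSON string -> dict again
    new_way = resp.model_dump()                   # dict directly, no round trip

    assert old_way == new_way
    # PATCH 4's guard in the same spirit: a dict logged as original_response is
    # turned back into a string only where a string is actually required.
    logged = json.dumps(new_way) if isinstance(new_way, dict) else new_way
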
From 2d8d7e35696756b581db914d1cb69a720ef35a46 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 6 Jan 2024 17:05:55 +0530
Subject: [PATCH 5/7] perf(router.py): don't use asyncio.wait for - just pass it to the completion call for timeouts

---
 litellm/router.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index 770098df0..39c0c4b56 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -352,18 +352,16 @@ class Router:
             else:
                 model_client = potential_model_client
             self.total_calls[model_name] += 1
-            response = await asyncio.wait_for(
-                litellm.acompletion(
+            response = await litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
                     "caching": self.cache_responses,
                     "client": model_client,
+                    "timeout": self.timeout,
                     **kwargs,
                 }
-                ),
-                timeout=self.timeout,
-            )
+            )
             self.success_calls[model_name] += 1
             return response
         except Exception as e:
@@ -614,18 +612,16 @@ class Router:
             else:
                 model_client = potential_model_client
             self.total_calls[model_name] += 1
-            response = await asyncio.wait_for(
-                litellm.atext_completion(
+            response = await litellm.atext_completion(
                 **{
                     **data,
                     "prompt": prompt,
                     "caching": self.cache_responses,
                     "client": model_client,
+                    "timeout": self.timeout,
                     **kwargs,
                 }
-                ),
-                timeout=self.timeout,
-            )
+            )
             self.success_calls[model_name] += 1
             return response
         except Exception as e:

From f2ad13af6511181cf7318ac144eb9d418631f726 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 6 Jan 2024 17:55:32 +0530
Subject: [PATCH 6/7] fix(openai.py): fix image generation model dump

---
 litellm/llms/openai.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 91a79fa57..3265b230f 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -711,7 +711,7 @@ class OpenAIChatCompletion(BaseLLM):
                 original_response=response,
             )
             # return response
-            return convert_to_model_response_object(response_object=model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
+            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
             raise e

From 3577857ed1ebde29bf9ab3502f81962f7f7483de Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 6 Jan 2024 21:52:58 +0530
Subject: [PATCH 7/7] fix(sagemaker.py): fix the post-call logging logic

---
 litellm/llms/sagemaker.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 2dfe7c8cb..1d341e7e9 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -360,14 +360,6 @@ def embedding(
     except Exception as e:
         raise SagemakerError(status_code=500, message=f"{str(e)}")
 
-    ## LOGGING
-    logging_obj.post_call(
-        input=input,
-        api_key="",
-        additional_args={"complete_input_dict": data},
-        original_response=response,
-    )
-
     response = json.loads(response["Body"].read().decode("utf8"))
     ## LOGGING
     logging_obj.post_call(
@@ -376,6 +368,7 @@ def embedding(
         original_response=response,
         additional_args={"complete_input_dict": data},
    )
+    print_verbose(f"raw model_response: {response}")
 
     if "embedding" not in response:
         raise SagemakerError(status_code=500, message="embedding not found in response")
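
The router change in PATCH 5 is an instance of a general asyncio trade-off: wrapping an awaitable in asyncio.wait_for() spawns an outer guard purely to cancel the call when the deadline passes, whereas passing the timeout down lets the underlying client enforce it itself and surface its own timeout error. A generic sketch of the two styles, using an invented stand-in coroutine rather than litellm code:

    # Generic illustration; call_backend is a made-up coroutine standing in for an HTTP client.
    import asyncio

    async def call_backend(timeout=None):
        # Pretend this client honours a per-request timeout internally.
        await asyncio.sleep(0.1)
        return "ok"

    async def wrapped_style():
        # Old style: an outer wait_for cancels the inner task when the deadline passes.
        return await asyncio.wait_for(call_backend(), timeout=30)

    async def passed_down_style():
        # New style: the deadline travels with the request; no extra wrapper task.
        return await call_backend(timeout=30)

    print(asyncio.run(wrapped_style()), asyncio.run(passed_down_style()))
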