Merge pull request #1344 from BerriAI/litellm_speed_improvements

Litellm speed improvements
Krish Dholakia 2024-01-06 22:38:10 +05:30 committed by GitHub
commit 439ee3bafc
7 changed files with 46 additions and 82 deletions


@@ -164,4 +164,3 @@ workflows:
branches:
  only:
    - main
-   - /litellm_.*/


@@ -1,36 +1,10 @@
# Base image for building
ARG LITELLM_BUILD_IMAGE=python:3.9
# Runtime image
ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
# Builder stage
FROM $LITELLM_BUILD_IMAGE as builder
# Set the working directory to /app
@@ -35,8 +34,12 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
WORKDIR /app
# Install build dependencies
RUN apt-get clean && apt-get update && \
apt-get install -y gcc python3-dev && \
rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip && \
pip install build
# Copy the current directory contents into the container at /app
COPY . .
# Build the package
RUN rm -rf dist/* && python -m build
# There should be only one wheel file now, assume the build only creates one
RUN ls -1 dist/*.whl | head -1
# Install the package
RUN pip install dist/*.whl
# install dependencies as wheels
RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
# Runtime stage
FROM $LITELLM_RUNTIME_IMAGE as runtime
@@ -43,8 +17,7 @@ RUN ls -la /app
# Copy the built wheel from the builder stage to the runtime stage; assumes only one wheel file is present
COPY --from=builder /app/dist/*.whl .
COPY --from=builder /wheels/ /wheels/
@@ -45,9 +48,17 @@ COPY --from=builder /wheels/ /wheels/
# Install the built wheel using pip; again using a wildcard if it's the only file
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
@@ -57,8 +30,8 @@ RUN if [ "$with_database" = "true" ]; then \
/app/retry_push.sh; \
fi
-EXPOSE 4000/tcp
+EXPOSE 8000/tcp
# Set your entrypoint and command
ENTRYPOINT ["litellm"]
-CMD ["--port", "4000"]
+CMD ["--config", "./proxy_server_config.yaml", "--port", "8000", "--num_workers", "8"]


@@ -248,7 +248,7 @@ class AzureChatCompletion(BaseLLM):
else:
azure_client = client
response = azure_client.chat.completions.create(**data, timeout=timeout) # type: ignore
-stringified_response = response.model_dump_json()
+stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=messages,
@@ -261,7 +261,7 @@ class AzureChatCompletion(BaseLLM):
},
)
return convert_to_model_response_object(
-response_object=json.loads(stringified_response),
+response_object=stringified_response,
model_response_object=model_response,
)
except AzureOpenAIError as e:
@@ -323,7 +323,7 @@ class AzureChatCompletion(BaseLLM):
**data, timeout=timeout
)
return convert_to_model_response_object(
-response_object=json.loads(response.model_dump_json()),
+response_object=response.model_dump(),
model_response_object=model_response,
)
except AzureOpenAIError as e:
@@ -465,7 +465,7 @@ class AzureChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data, timeout=timeout)
-stringified_response = response.model_dump_json()
+stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=input,
@@ -474,7 +474,7 @@ class AzureChatCompletion(BaseLLM):
original_response=stringified_response,
)
return convert_to_model_response_object(
-response_object=json.loads(stringified_response),
+response_object=stringified_response,
model_response_object=model_response,
response_type="embedding",
)
@@ -564,7 +564,7 @@ class AzureChatCompletion(BaseLLM):
original_response=response,
)
-return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
+return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
@@ -599,7 +599,7 @@ class AzureChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data, timeout=timeout)
-stringified_response = response.model_dump_json()
+stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=input,
@@ -608,7 +608,7 @@ class AzureChatCompletion(BaseLLM):
original_response=stringified_response,
)
return convert_to_model_response_object(
-response_object=json.loads(stringified_response),
+response_object=stringified_response,
model_response_object=model_response,
response_type="image_generation",
)
@@ -697,7 +697,7 @@ class AzureChatCompletion(BaseLLM):
original_response=response,
)
# return response
-return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
+return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
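
Note on the recurring change in this file: the Azure SDK responses are pydantic models, and the old code serialized each one to a JSON string with model_dump_json() only to parse it straight back with json.loads() before handing it to convert_to_model_response_object(). The new code asks pydantic for the dict directly with model_dump(), skipping the serialize/parse round trip. A minimal sketch of the equivalence, using a stand-in pydantic model rather than the real SDK response type:

    import json
    from pydantic import BaseModel

    # Stand-in for the SDK response objects touched in this diff (chat
    # completions, embeddings, image generations); the real ones are also
    # pydantic models.
    class FakeResponse(BaseModel):
        id: str
        choices: list[dict]

    resp = FakeResponse(id="chatcmpl-123", choices=[{"index": 0, "text": "hi"}])

    old_way = json.loads(resp.model_dump_json())  # dict via a string round trip
    new_way = resp.model_dump()                   # dict directly, no round trip

    assert old_way == new_way

One caveat: model_dump() leaves non-JSON-native values (e.g. datetimes) as Python objects rather than strings; the fields on these responses are JSON-native, so the downstream conversion should see the same data either way.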


@@ -280,18 +280,6 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=max_retries,
)
else:
-## LOGGING
-logging_obj.pre_call(
-input=messages,
-api_key=api_key,
-additional_args={
-"headers": headers,
-"api_base": api_base,
-"acompletion": acompletion,
-"complete_input_dict": data,
-},
-)
if not isinstance(max_retries, int):
raise OpenAIError(
status_code=422, message="max retries must be an int"
@@ -306,8 +294,21 @@ class OpenAIChatCompletion(BaseLLM):
)
else:
openai_client = client
+## LOGGING
+logging_obj.pre_call(
+input=messages,
+api_key=openai_client.api_key,
+additional_args={
+"headers": headers,
+"api_base": openai_client._base_url._uri_reference,
+"acompletion": acompletion,
+"complete_input_dict": data,
+},
+)
response = openai_client.chat.completions.create(**data, timeout=timeout) # type: ignore
-stringified_response = response.model_dump_json()
+stringified_response = response.model_dump()
logging_obj.post_call(
input=messages,
api_key=api_key,
@@ -315,7 +316,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
-response_object=json.loads(stringified_response),
+response_object=stringified_response,
model_response_object=model_response,
)
except Exception as e:
@@ -386,7 +387,7 @@ class OpenAIChatCompletion(BaseLLM):
response = await openai_aclient.chat.completions.create(
**data, timeout=timeout
)
-stringified_response = response.model_dump_json()
+stringified_response = response.model_dump()
logging_obj.post_call(
input=data["messages"],
api_key=api_key,
@@ -394,7 +395,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
-response_object=json.loads(stringified_response),
+response_object=stringified_response,
model_response_object=model_response,
)
except Exception as e:
@@ -527,7 +528,7 @@ class OpenAIChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
-stringified_response = response.model_dump_json()
+stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=input,
@@ -535,7 +536,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
-return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
+return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="embedding") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
@@ -597,7 +598,7 @@ class OpenAIChatCompletion(BaseLLM):
original_response=response,
)
-return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
+return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
raise e
@@ -634,7 +635,7 @@ class OpenAIChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore
-stringified_response = response.model_dump_json()
+stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
@@ -642,7 +643,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
-return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation") # type: ignore
+return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
@@ -710,7 +711,7 @@ class OpenAIChatCompletion(BaseLLM):
original_response=response,
)
# return response
-return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
+return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
raise e
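
The same model_dump() substitution runs through this file; the other notable change is that the ## LOGGING pre_call block moves from before client resolution to after it, so the log records the api_key and base URL of the client that is actually used (which may be a caller-supplied client) rather than the raw function arguments. A small sketch of reading those values off an openai.OpenAI client; the key and base URL below are placeholders, and str(client.base_url) is the public counterpart of the private _base_url._uri_reference attribute the diff reads:

    from openai import OpenAI

    # Placeholder credentials purely for illustration.
    openai_client = OpenAI(api_key="sk-placeholder", base_url="https://example.com/v1")

    log_payload = {
        "api_key": openai_client.api_key,
        "api_base": str(openai_client.base_url),  # diff uses _base_url._uri_reference
    }
    print(log_payload)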


@@ -360,14 +360,6 @@ def embedding(
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
-## LOGGING
-logging_obj.post_call(
-input=input,
-api_key="",
-additional_args={"complete_input_dict": data},
-original_response=response,
-)
response = json.loads(response["Body"].read().decode("utf8"))
## LOGGING
logging_obj.post_call(
@@ -376,6 +368,7 @@ def embedding(
original_response=response,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response}")
if "embedding" not in response:
raise SagemakerError(status_code=500, message="embedding not found in response")
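
Here the embedding path used to log post_call twice: once with the raw botocore response (whose Body is an unread stream) and again after decoding. The first call is removed, so the response is logged once, after the body has been decoded. A minimal sketch of the decode-and-validate step, with an in-memory stand-in for the boto3 invoke_endpoint response:

    import io
    import json

    # Stand-in for the botocore response dict; only the "Body" stream matters here.
    fake_response = {"Body": io.BytesIO(b'{"embedding": [[0.1, 0.2, 0.3]]}')}

    decoded = json.loads(fake_response["Body"].read().decode("utf8"))
    if "embedding" not in decoded:
        raise ValueError("embedding not found in response")
    print(decoded["embedding"])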


@@ -352,17 +352,15 @@ class Router:
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
-response = await asyncio.wait_for(
-litellm.acompletion(
+response = await litellm.acompletion(
**{
**data,
"messages": messages,
"caching": self.cache_responses,
"client": model_client,
+"timeout": self.timeout,
**kwargs,
}
-),
-timeout=self.timeout,
)
self.success_calls[model_name] += 1
return response
@@ -614,17 +612,15 @@ class Router:
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
-response = await asyncio.wait_for(
-litellm.atext_completion(
+response = await litellm.atext_completion(
**{
**data,
"prompt": prompt,
"caching": self.cache_responses,
"client": model_client,
+"timeout": self.timeout,
**kwargs,
}
-),
-timeout=self.timeout,
)
self.success_calls[model_name] += 1
return response
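
Both router hunks drop the asyncio.wait_for wrapper and instead pass "timeout": self.timeout directly into the litellm.acompletion / litellm.atext_completion call, so the deadline travels as an ordinary parameter rather than being enforced by an extra wrapper task. A minimal sketch of the two shapes, with a dummy coroutine standing in for the litellm call:

    import asyncio

    async def fake_acompletion(**kwargs):
        # Stand-in for litellm.acompletion; the real call passes its timeout
        # on to the underlying client.
        await asyncio.sleep(0.01)
        return {"ok": True, "timeout_seen": kwargs.get("timeout")}

    async def old_shape(data, timeout):
        # Before: an outer task enforces the deadline and cancels on expiry.
        return await asyncio.wait_for(fake_acompletion(**data), timeout=timeout)

    async def new_shape(data, timeout):
        # After: the deadline rides along as a normal keyword argument.
        return await fake_acompletion(**data, timeout=timeout)

    print(asyncio.run(old_shape({"model": "gpt-3.5-turbo"}, timeout=5)))
    print(asyncio.run(new_shape({"model": "gpt-3.5-turbo"}, timeout=5)))

The upside is one less task per request; the implicit trade-off is that the deadline is now only as strict as the downstream client's handling of its timeout parameter.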


@@ -956,6 +956,8 @@ class Logging:
):
# Log the exact result from the LLM API, for streaming - log the type of response received
litellm.error_logs["POST_CALL"] = locals()
+if isinstance(original_response, dict):
+original_response = json.dumps(original_response)
try:
self.model_call_details["input"] = input
self.model_call_details["api_key"] = api_key
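
Because callers now hand post_call a dict (from model_dump()) instead of a JSON string, the Logging class normalizes it once up front and the rest of the method keeps working with strings. The added guard in isolation, wrapped in a hypothetical helper name purely for illustration:

    import json

    def normalize_original_response(original_response):
        # Mirrors the two added lines in Logging.post_call: dict responses are
        # stringified once here instead of via model_dump_json() at every call site.
        if isinstance(original_response, dict):
            original_response = json.dumps(original_response)
        return original_response

    print(normalize_original_response({"id": "chatcmpl-123", "choices": []}))
    print(normalize_original_response("already a string"))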