perf(sagemaker.py): asyncify hf prompt template check

Leads to a 189% improvement in RPS @ 100 users.
Krrish Dholakia 2024-08-23 15:45:42 -07:00
parent b0f01e5b95
commit 2cf149fbad
4 changed files with 253 additions and 125 deletions
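
The gist of the change: building the prompt for a SageMaker request can require fetching a Hugging Face tokenizer config over HTTP (see the `hf_chat_template` hunk below), and this blocking work previously ran directly on the event loop inside the async code path. The commit moves it into a worker thread via a new `asyncify` helper and caches the fetched tokenizer config. A minimal sketch of the pattern, with simplified hypothetical names (`build_prompt`, `handle_request`) standing in for the real methods:

import time

from litellm.litellm_core_utils.asyncify import asyncify


def build_prompt(model: str, messages: list) -> str:
    # Stand-in for the blocking work done in SagemakerLLM._transform_prompt below:
    # fetching a HF tokenizer config over HTTP and rendering the chat template.
    time.sleep(0.5)  # simulate the network round trip
    return "\n".join(m["content"] for m in messages)


async def handle_request(model: str, messages: list) -> str:
    # Calling build_prompt directly would stall the event loop for the whole
    # fetch; asyncify runs it in a worker thread so other requests keep moving.
    async_build_prompt = asyncify(build_prompt)
    return await async_build_prompt(model, messages)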

View file

@@ -0,0 +1,67 @@
import functools
from typing import Awaitable, Callable, ParamSpec, TypeVar

import anyio
from anyio import to_thread

T_ParamSpec = ParamSpec("T_ParamSpec")
T_Retval = TypeVar("T_Retval")


def function_has_argument(function: Callable, arg_name: str) -> bool:
    """Helper function to check if a function has a specific argument."""
    import inspect

    signature = inspect.signature(function)
    return arg_name in signature.parameters


def asyncify(
    function: Callable[T_ParamSpec, T_Retval],
    *,
    cancellable: bool = False,
    limiter: anyio.CapacityLimiter | None = None,
) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
    """
    Take a blocking function and create an async one that receives the same
    positional and keyword arguments, and that when called, calls the original function
    in a worker thread using `anyio.to_thread.run_sync()`.

    If the `cancellable` option is enabled and the task waiting for its completion is
    cancelled, the thread will still run its course but its return value (or any raised
    exception) will be ignored.

    ## Arguments

    - `function`: a blocking regular callable (e.g. a function)
    - `cancellable`: `True` to allow cancellation of the operation
    - `limiter`: capacity limiter to use to limit the total amount of threads running
      (if omitted, the default limiter is used)

    ## Return

    An async function that takes the same positional and keyword arguments as the
    original one, that when called runs the same original function in a thread worker
    and returns the result.
    """

    async def wrapper(
        *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
    ) -> T_Retval:
        partial_f = functools.partial(function, *args, **kwargs)

        # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
        # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
        # surfacing deprecation warnings.
        if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
            return await anyio.to_thread.run_sync(
                partial_f,
                abandon_on_cancel=cancellable,
                limiter=limiter,
            )

        return await anyio.to_thread.run_sync(
            partial_f,
            cancellable=cancellable,
            limiter=limiter,
        )

    return wrapper
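
For illustration only (not part of the commit), the helper above can also bound how many worker threads a given blocking callable may occupy at once via the `limiter` argument; `fetch_config` and the URL here are hypothetical:

import time

import anyio

from litellm.litellm_core_utils.asyncify import asyncify


def fetch_config(url: str) -> dict:
    time.sleep(1)  # stand-in for blocking HTTP I/O
    return {"url": url}


async def main() -> None:
    # Allow at most 10 concurrent worker threads for this callable; extra
    # callers wait on the limiter instead of spawning more threads.
    limiter = anyio.CapacityLimiter(10)
    fetch_config_async = asyncify(fetch_config, limiter=limiter)
    config = await fetch_config_async("https://example.com/tokenizer_config.json")
    print(config)


anyio.run(main)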

View file

@@ -400,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
        tokenizer_config = known_tokenizer_config[model]
    else:
        tokenizer_config = _get_tokenizer_config(model)
        known_tokenizer_config.update({model: tokenizer_config})

    if (
        tokenizer_config["status"] == "failure"

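The added `known_tokenizer_config.update(...)` line memoizes the fetched tokenizer config per model, so only the first request pays for the `_get_tokenizer_config` HTTP round trip. A minimal sketch of the same caching pattern, with hypothetical names (`fetch_tokenizer_config`, `_config_cache`):

from typing import Dict


def fetch_tokenizer_config(model: str) -> dict:
    # Hypothetical stand-in for _get_tokenizer_config's blocking HTTP call.
    return {"status": "success", "chat_template": "..."}


_config_cache: Dict[str, dict] = {}


def get_tokenizer_config_cached(model: str) -> dict:
    # First call per model fetches and stores; later calls hit the dict.
    if model not in _config_cache:
        _config_cache[model] = fetch_tokenizer_config(model)
    return _config_cache[model]
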
View file

@@ -15,6 +15,7 @@ import requests # type: ignore

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
@@ -24,11 +25,8 @@ from litellm.llms.custom_httpx.http_handler import (
from litellm.types.llms.openai import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    OpenAIChatCompletionChunk,
)
from litellm.types.utils import CustomStreamingDecoder
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk
from litellm.utils import (
    CustomStreamWrapper,
    EmbeddingResponse,
@@ -37,8 +35,8 @@ from litellm.utils import (
    get_secret,
)

from ..base_aws_llm import BaseAWSLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
from .base_aws_llm import BaseAWSLLM
from .prompt_templates.factory import custom_prompt, prompt_factory

_response_stream_shape_cache = None
@@ -201,6 +199,49 @@ class SagemakerLLM(BaseAWSLLM):

        return prepped_request

    def _transform_prompt(
        self,
        model: str,
        messages: List,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
    ) -> str:
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        elif hf_model_name in custom_prompt_dict:
            # check if the base huggingface model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[hf_model_name]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        else:
            if hf_model_name is None:
                if "llama-2" in model.lower():  # llama-2 model
                    if "chat" in model.lower():  # apply llama2 chat template
                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
                    else:  # apply regular llama2 template
                        hf_model_name = "meta-llama/Llama-2-7b"
            hf_model_name = (
                hf_model_name or model
            )  # pass in hf model name for pulling its prompt template (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
            prompt = prompt_factory(model=hf_model_name, messages=messages)

        return prompt

    def completion(
        self,
        model: str,
@@ -244,10 +285,6 @@
                aws_region_name=aws_region_name,
            )

            custom_stream_decoder = AWSEventStreamDecoder(
                model="", is_messages_api=True
            )

            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
@@ -266,7 +303,6 @@
                headers=prepared_request.headers,
                custom_endpoint=True,
                custom_llm_provider="sagemaker_chat",
                streaming_decoder=custom_stream_decoder,  # type: ignore
            )

        ## Load Config
@@ -277,42 +313,8 @@
            ):  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        elif hf_model_name in custom_prompt_dict:
            # check if the base huggingface model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[hf_model_name]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        else:
            if hf_model_name is None:
                if "llama-2" in model.lower():  # llama-2 model
                    if "chat" in model.lower():  # apply llama2 chat template
                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
                    else:  # apply regular llama2 template
                        hf_model_name = "meta-llama/Llama-2-7b"
            hf_model_name = (
                hf_model_name or model
            )  # pass in hf model name for pulling its prompt template (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
            prompt = prompt_factory(model=hf_model_name, messages=messages)

        if stream is True:
            data = {"inputs": prompt, "parameters": inference_params, "stream": True}
            data = {"parameters": inference_params, "stream": True}
            prepared_request = self._prepare_request(
                model=model,
                data=data,
@@ -329,18 +331,41 @@
            if acompletion is True:
                response = self.async_streaming(
                    prepared_request=prepared_request,
                    messages=messages,
                    model=model,
                    custom_prompt_dict=custom_prompt_dict,
                    hf_model_name=hf_model_name,
                    optional_params=optional_params,
                    encoding=encoding,
                    model_response=model_response,
                    model=model,
                    logging_obj=logging_obj,
                    data=data,
                    model_id=model_id,
                    aws_region_name=aws_region_name,
                    credentials=credentials,
                )
                return response
            else:
                if stream is not None and stream is True:
                    prompt = self._transform_prompt(
                        model=model,
                        messages=messages,
                        custom_prompt_dict=custom_prompt_dict,
                        hf_model_name=hf_model_name,
                    )
                    data["inputs"] = prompt
                    prepared_request = self._prepare_request(
                        model=model,
                        data=data,
                        optional_params=optional_params,
                        credentials=credentials,
                        aws_region_name=aws_region_name,
                    )
                    if model_id is not None:
                        # Add model_id as InferenceComponentName header
                        # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
                        prepared_request.headers.update(
                            {"X-Amzn-SageMaker-Inference-Component": model_id}
                        )
                    sync_handler = _get_httpx_client()
                    sync_response = sync_handler.post(
                        url=prepared_request.url,
@@ -377,27 +402,41 @@
                    return streaming_response

        # Non-Streaming Requests
        _data = {"inputs": prompt, "parameters": inference_params}
        prepared_request = self._prepare_request(
        _data = {"parameters": inference_params}
        prepared_request_args = {
            "model": model,
            "data": _data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
        }

        # Async completion
        if acompletion is True:
            return self.async_completion(
                messages=messages,
                model=model,
                custom_prompt_dict=custom_prompt_dict,
                hf_model_name=hf_model_name,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                data=_data,
                model_id=model_id,
                optional_params=optional_params,
                credentials=credentials,
                aws_region_name=aws_region_name,
            )

        # Async completion
        if acompletion is True:
            return self.async_completion(
                prepared_request=prepared_request,
                model_response=model_response,
                encoding=encoding,
        prompt = self._transform_prompt(
            model=model,
                logging_obj=logging_obj,
                data=_data,
                model_id=model_id,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        _data["inputs"] = prompt

        ## Non-Streaming completion CALL
        prepared_request = self._prepare_request(**prepared_request_args)
        try:
            if model_id is not None:
                # Add model_id as InferenceComponentName header
@@ -483,7 +522,7 @@
                completion_output = completion_output.replace(prompt, "", 1)

            model_response.choices[0].message.content = completion_output  # type: ignore
        except:
        except Exception:
            raise SagemakerError(
                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                status_code=500,
@@ -555,15 +594,34 @@
    async def async_streaming(
        self,
        prepared_request,
        messages: list,
        model: str,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
        credentials,
        aws_region_name: str,
        optional_params,
        encoding,
        model_response: ModelResponse,
        model: str,
        model_id: Optional[str],
        logging_obj: Any,
        data,
    ):
        data["inputs"] = self._transform_prompt(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        asyncified_prepare_request = asyncify(self._prepare_request)
        prepared_request_args = {
            "model": model,
            "data": data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
        }
        prepared_request = await asyncified_prepare_request(**prepared_request_args)

        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
@@ -590,16 +648,40 @@
    async def async_completion(
        self,
        prepared_request,
        messages: list,
        model: str,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
        credentials,
        aws_region_name: str,
        encoding,
        model_response: ModelResponse,
        model: str,
        optional_params: dict,
        logging_obj: Any,
        data: dict,
        model_id: Optional[str],
    ):
        timeout = 300.0
        async_handler = _get_async_httpx_client()

        async_transform_prompt = asyncify(self._transform_prompt)
        data["inputs"] = await async_transform_prompt(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        asyncified_prepare_request = asyncify(self._prepare_request)
        prepared_request_args = {
            "model": model,
            "data": data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
        }
        prepared_request = await asyncified_prepare_request(**prepared_request_args)

        ## LOGGING
        logging_obj.pre_call(
            input=[],
@@ -669,7 +751,7 @@
                completion_output = completion_output.replace(data["inputs"], "", 1)

            model_response.choices[0].message.content = completion_output  # type: ignore
        except:
        except Exception:
            raise SagemakerError(
                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                status_code=500,
@@ -855,21 +937,12 @@ def get_response_stream_shape():


class AWSEventStreamDecoder:
    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
    def __init__(self, model: str) -> None:
        from botocore.parsers import EventStreamJSONParser

        self.model = model
        self.parser = EventStreamJSONParser()
        self.content_blocks: List = []
        self.is_messages_api = is_messages_api

    def _chunk_parser_messages_api(
        self, chunk_data: dict
    ) -> StreamingChatCompletionChunk:
        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
        return openai_chunk

    def _chunk_parser(self, chunk_data: dict) -> GChunk:
        verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -885,7 +958,6 @@ class AWSEventStreamDecoder:
                index=_index,
                is_finished=True,
                finish_reason="stop",
                usage=None,
            )

        return GChunk(
@@ -893,12 +965,9 @@
            index=_index,
            is_finished=is_finished,
            finish_reason=finish_reason,
            usage=None,
        )

    def iter_bytes(
        self, iterator: Iterator[bytes]
    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer
@@ -919,9 +988,6 @@ class AWSEventStreamDecoder:
                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                        yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
@@ -933,20 +999,16 @@ class AWSEventStreamDecoder:
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError as e:
            except json.JSONDecodeError:
                # Handle or log any unparseable data at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None
    async def aiter_bytes(
        self, iterator: AsyncIterator[bytes]
    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
    ) -> AsyncIterator[GChunk]:
        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer
@@ -968,9 +1030,6 @@ class AWSEventStreamDecoder:
                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                        yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
@@ -982,16 +1041,12 @@ class AWSEventStreamDecoder:
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError:
                # Handle or log any unparseable data at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None
    def _parse_message_from_event(self, event) -> Optional[str]:
        response_dict = event.to_response_dict()

View file

@@ -1,7 +1,12 @@
model_list:
  - model_name: "*"
  - model_name: fake-openai-endpoint
    litellm_params:
      model: "*"
      model: sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614
      # sagemaker_base_url: https://exampleopenaiendpoint-production.up.railway.app/invocations/
      # api_base: https://exampleopenaiendpoint-production.up.railway.app

general_settings:
  global_max_parallel_requests: 0
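
For reference, a minimal call that exercises this code path through the SDK rather than the proxy (illustrative only; assumes valid AWS credentials and the SageMaker endpoint named in the config above):

import asyncio

import litellm


async def main() -> None:
    response = await litellm.acompletion(
        model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())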