perf(sagemaker.py): asyncify hf prompt template check

leads to 189% improvement in RPS @ 100 users
Krrish Dholakia 2024-08-23 15:45:42 -07:00
parent b0f01e5b95
commit 2cf149fbad
4 changed files with 253 additions and 125 deletions


@@ -0,0 +1,67 @@
import functools
from typing import Awaitable, Callable, ParamSpec, TypeVar

import anyio
from anyio import to_thread

T_ParamSpec = ParamSpec("T_ParamSpec")
T_Retval = TypeVar("T_Retval")


def function_has_argument(function: Callable, arg_name: str) -> bool:
    """Helper function to check if a function has a specific argument."""
    import inspect

    signature = inspect.signature(function)
    return arg_name in signature.parameters


def asyncify(
    function: Callable[T_ParamSpec, T_Retval],
    *,
    cancellable: bool = False,
    limiter: anyio.CapacityLimiter | None = None,
) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
    """
    Take a blocking function and create an async one that receives the same
    positional and keyword arguments, and that when called, calls the original function
    in a worker thread using `anyio.to_thread.run_sync()`.

    If the `cancellable` option is enabled and the task waiting for its completion is
    cancelled, the thread will still run its course but its return value (or any raised
    exception) will be ignored.

    ## Arguments

    - `function`: a blocking regular callable (e.g. a function)
    - `cancellable`: `True` to allow cancellation of the operation
    - `limiter`: capacity limiter to use to limit the total amount of threads running
      (if omitted, the default limiter is used)

    ## Return

    An async function that takes the same positional and keyword arguments as the
    original one, that when called runs the same original function in a thread worker
    and returns the result.
    """

    async def wrapper(
        *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
    ) -> T_Retval:
        partial_f = functools.partial(function, *args, **kwargs)

        # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
        # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
        # surfacing deprecation warnings.
        if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
            return await anyio.to_thread.run_sync(
                partial_f,
                abandon_on_cancel=cancellable,
                limiter=limiter,
            )

        return await anyio.to_thread.run_sync(
            partial_f,
            cancellable=cancellable,
            limiter=limiter,
        )

    return wrapper
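
A minimal usage sketch (not part of this commit): wrap a blocking helper once with `asyncify`, then await the wrapper from async code. `slow_lookup` is a made-up stand-in for a blocking call such as the prompt-template check asyncified later in this commit.

import time

import anyio

from litellm.litellm_core_utils.asyncify import asyncify


def slow_lookup(model: str) -> str:
    time.sleep(0.5)  # stand-in for blocking I/O, e.g. an HTTP call to the Hugging Face Hub
    return f"template-for-{model}"


async def main() -> None:
    async_lookup = asyncify(slow_lookup)       # wrap the blocking function once
    template = await async_lookup("my-model")  # runs in a worker thread; the event loop stays free
    print(template)


anyio.run(main)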


@@ -400,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
         tokenizer_config = known_tokenizer_config[model]
     else:
         tokenizer_config = _get_tokenizer_config(model)
+        known_tokenizer_config.update({model: tokenizer_config})
     if (
         tokenizer_config["status"] == "failure"
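
A standalone sketch (not from the commit) of the caching pattern this hunk introduces: fetch a model's tokenizer_config.json once, then serve repeat calls from a module-level dict so later requests skip the network round trip. The Hub URL and the `status` wrapper shape are assumptions for illustration.

import requests

_tokenizer_config_cache: dict = {}


def get_tokenizer_config_cached(model: str) -> dict:
    """Return (and cache) the tokenizer config for `model`."""
    if model in _tokenizer_config_cache:
        return _tokenizer_config_cache[model]  # cache hit: no blocking HTTP call
    url = f"https://huggingface.co/{model}/raw/main/tokenizer_config.json"  # assumed URL layout
    resp = requests.get(url, timeout=10)
    config = (
        {"status": "success", "tokenizer": resp.json()}
        if resp.status_code == 200
        else {"status": "failure"}
    )
    _tokenizer_config_cache[model] = config
    return config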


@@ -15,6 +15,7 @@ import requests  # type: ignore
 import litellm
 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.asyncify import asyncify
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -24,11 +25,8 @@ from litellm.llms.custom_httpx.http_handler import (
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
-    OpenAIChatCompletionChunk,
 )
-from litellm.types.utils import CustomStreamingDecoder
 from litellm.types.utils import GenericStreamingChunk as GChunk
-from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -37,8 +35,8 @@ from litellm.utils import (
     get_secret,
 )

-from ..base_aws_llm import BaseAWSLLM
-from ..prompt_templates.factory import custom_prompt, prompt_factory
+from .base_aws_llm import BaseAWSLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory

 _response_stream_shape_cache = None
@@ -201,6 +199,49 @@ class SagemakerLLM(BaseAWSLLM):

         return prepped_request

+    def _transform_prompt(
+        self,
+        model: str,
+        messages: List,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+    ) -> str:
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        elif hf_model_name in custom_prompt_dict:
+            # check if the base huggingface model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[hf_model_name]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        else:
+            if hf_model_name is None:
+                if "llama-2" in model.lower():  # llama-2 model
+                    if "chat" in model.lower():  # apply llama2 chat template
+                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
+                    else:  # apply regular llama2 template
+                        hf_model_name = "meta-llama/Llama-2-7b"
+            hf_model_name = (
+                hf_model_name or model
+            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
+            prompt = prompt_factory(model=hf_model_name, messages=messages)
+
+        return prompt
+
     def completion(
         self,
         model: str,
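
An illustrative sketch (not in the commit) of the helper's fallthrough branch: with nothing registered in `custom_prompt_dict`, a llama-2 chat endpoint resolves to the `meta-llama/Llama-2-7b-chat-hf` template via `prompt_factory`. The absolute import path is inferred from the relative import used in this file.

from litellm.llms.prompt_templates.factory import prompt_factory

messages = [{"role": "user", "content": "Hello!"}]
# the endpoint name contains "llama-2" and "chat", so _transform_prompt would pick the
# llama-2 chat template; calling prompt_factory directly shows the rendered prompt
prompt = prompt_factory(model="meta-llama/Llama-2-7b-chat-hf", messages=messages)
print(prompt)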
@@ -244,10 +285,6 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )

-            custom_stream_decoder = AWSEventStreamDecoder(
-                model="", is_messages_api=True
-            )
-
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -266,7 +303,6 @@
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
-                streaming_decoder=custom_stream_decoder,  # type: ignore
             )

         ## Load Config
@@ -277,42 +313,8 @@
             ):  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v

-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", None),
-                initial_prompt_value=model_prompt_details.get(
-                    "initial_prompt_value", ""
-                ),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                messages=messages,
-            )
-        elif hf_model_name in custom_prompt_dict:
-            # check if the base huggingface model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[hf_model_name]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", None),
-                initial_prompt_value=model_prompt_details.get(
-                    "initial_prompt_value", ""
-                ),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                messages=messages,
-            )
-        else:
-            if hf_model_name is None:
-                if "llama-2" in model.lower():  # llama-2 model
-                    if "chat" in model.lower():  # apply llama2 chat template
-                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
-                    else:  # apply regular llama2 template
-                        hf_model_name = "meta-llama/Llama-2-7b"
-            hf_model_name = (
-                hf_model_name or model
-            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
-            prompt = prompt_factory(model=hf_model_name, messages=messages)
-
         if stream is True:
-            data = {"inputs": prompt, "parameters": inference_params, "stream": True}
+            data = {"parameters": inference_params, "stream": True}
             prepared_request = self._prepare_request(
                 model=model,
                 data=data,
@@ -329,18 +331,41 @@

             if acompletion is True:
                 response = self.async_streaming(
-                    prepared_request=prepared_request,
+                    messages=messages,
+                    model=model,
+                    custom_prompt_dict=custom_prompt_dict,
+                    hf_model_name=hf_model_name,
                     optional_params=optional_params,
                     encoding=encoding,
                     model_response=model_response,
-                    model=model,
                     logging_obj=logging_obj,
                     data=data,
                     model_id=model_id,
+                    aws_region_name=aws_region_name,
+                    credentials=credentials,
                 )
                 return response
             else:
-                if stream is not None and stream is True:
+                prompt = self._transform_prompt(
+                    model=model,
+                    messages=messages,
+                    custom_prompt_dict=custom_prompt_dict,
+                    hf_model_name=hf_model_name,
+                )
+                data["inputs"] = prompt
+                prepared_request = self._prepare_request(
+                    model=model,
+                    data=data,
+                    optional_params=optional_params,
+                    credentials=credentials,
+                    aws_region_name=aws_region_name,
+                )
+                if model_id is not None:
+                    # Add model_id as InferenceComponentName header
+                    # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
+                    prepared_request.headers.update(
+                        {"X-Amzn-SageMaker-Inference-Component": model_id}
+                    )
                 sync_handler = _get_httpx_client()
                 sync_response = sync_handler.post(
                     url=prepared_request.url,
@@ -377,27 +402,41 @@
             return streaming_response

         # Non-Streaming Requests
-        _data = {"inputs": prompt, "parameters": inference_params}
-        prepared_request = self._prepare_request(
-            model=model,
-            data=_data,
-            optional_params=optional_params,
-            credentials=credentials,
-            aws_region_name=aws_region_name,
-        )
+        _data = {"parameters": inference_params}
+        prepared_request_args = {
+            "model": model,
+            "data": _data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }

         # Async completion
         if acompletion is True:
             return self.async_completion(
-                prepared_request=prepared_request,
-                model_response=model_response,
-                encoding=encoding,
+                messages=messages,
                 model=model,
+                custom_prompt_dict=custom_prompt_dict,
+                hf_model_name=hf_model_name,
+                model_response=model_response,
+                encoding=encoding,
                 logging_obj=logging_obj,
                 data=_data,
                 model_id=model_id,
+                optional_params=optional_params,
+                credentials=credentials,
+                aws_region_name=aws_region_name,
             )

+        prompt = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        _data["inputs"] = prompt
+
         ## Non-Streaming completion CALL
+        prepared_request = self._prepare_request(**prepared_request_args)
         try:
             if model_id is not None:
                 # Add model_id as InferenceComponentName header
@@ -483,7 +522,7 @@ class SagemakerLLM(BaseAWSLLM):
                 completion_output = completion_output.replace(prompt, "", 1)

             model_response.choices[0].message.content = completion_output  # type: ignore
-        except:
+        except Exception:
             raise SagemakerError(
                 message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                 status_code=500,
@@ -555,15 +594,34 @@

     async def async_streaming(
         self,
-        prepared_request,
+        messages: list,
+        model: str,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+        credentials,
+        aws_region_name: str,
         optional_params,
         encoding,
         model_response: ModelResponse,
-        model: str,
         model_id: Optional[str],
         logging_obj: Any,
         data,
     ):
+        data["inputs"] = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        asyncified_prepare_request = asyncify(self._prepare_request)
+        prepared_request_args = {
+            "model": model,
+            "data": data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
+        prepared_request = await asyncified_prepare_request(**prepared_request_args)
         streaming_response = CustomStreamWrapper(
             completion_stream=None,
             make_call=partial(
@@ -590,16 +648,40 @@

     async def async_completion(
         self,
-        prepared_request,
+        messages: list,
+        model: str,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+        credentials,
+        aws_region_name: str,
         encoding,
         model_response: ModelResponse,
-        model: str,
+        optional_params: dict,
         logging_obj: Any,
         data: dict,
         model_id: Optional[str],
     ):
         timeout = 300.0
         async_handler = _get_async_httpx_client()
+
+        async_transform_prompt = asyncify(self._transform_prompt)
+        data["inputs"] = await async_transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        asyncified_prepare_request = asyncify(self._prepare_request)
+        prepared_request_args = {
+            "model": model,
+            "data": data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
+        prepared_request = await asyncified_prepare_request(**prepared_request_args)
+
         ## LOGGING
         logging_obj.pre_call(
             input=[],
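
A self-contained sketch (not from the commit, and not a reproduction of the 189% figure) of why offloading the blocking prompt-template and request-signing work helps throughput: while the blocking call runs in a worker thread, the event loop keeps serving other coroutines instead of stalling. Names below are illustrative only.

import time

import anyio
from anyio import to_thread


def blocking_template_lookup() -> str:
    time.sleep(1)  # stand-in for a blocking HTTP call (e.g. fetching a tokenizer config)
    return "<prompt template>"


async def heartbeat() -> None:
    for _ in range(5):
        print("event loop still responsive")
        await anyio.sleep(0.25)


async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(heartbeat)
        # offload the blocking call to a worker thread; awaiting it does not block the loop
        result = await to_thread.run_sync(blocking_template_lookup)
        print("got:", result)


anyio.run(main)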
@@ -669,7 +751,7 @@ class SagemakerLLM(BaseAWSLLM):
                 completion_output = completion_output.replace(data["inputs"], "", 1)

             model_response.choices[0].message.content = completion_output  # type: ignore
-        except:
+        except Exception:
             raise SagemakerError(
                 message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                 status_code=500,
@@ -855,21 +937,12 @@ def get_response_stream_shape():

 class AWSEventStreamDecoder:
-    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
+    def __init__(self, model: str) -> None:
         from botocore.parsers import EventStreamJSONParser

         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
-        self.is_messages_api = is_messages_api
-
-    def _chunk_parser_messages_api(
-        self, chunk_data: dict
-    ) -> StreamingChatCompletionChunk:
-        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
-        return openai_chunk

     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -885,7 +958,6 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
-                usage=None,
             )

         return GChunk(
@@ -893,12 +965,9 @@
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
-            usage=None,
         )

-    def iter_bytes(
-        self, iterator: Iterator[bytes]
-    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -919,9 +988,6 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        if self.is_messages_api:
-                            yield self._chunk_parser_messages_api(chunk_data=_data)
-                        else:
-                            yield self._chunk_parser(chunk_data=_data)
+                        yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -933,20 +999,16 @@
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                if self.is_messages_api:
-                    yield self._chunk_parser_messages_api(chunk_data=_data)
-                else:
-                    yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError as e:
+                yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
-                yield None

     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+    ) -> AsyncIterator[GChunk]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -968,9 +1030,6 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        if self.is_messages_api:
-                            yield self._chunk_parser_messages_api(chunk_data=_data)
-                        else:
-                            yield self._chunk_parser(chunk_data=_data)
+                        yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -982,16 +1041,12 @@
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                if self.is_messages_api:
-                    yield self._chunk_parser_messages_api(chunk_data=_data)
-                else:
-                    yield self._chunk_parser(chunk_data=_data)
+                yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
-                yield None

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
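
A standalone sketch (not from the commit) of the accumulate-then-parse pattern the decoder above uses for chunked JSON payloads: keep appending chunk text until the buffer parses as JSON, then emit the object and reset.

import json
from typing import Iterator


def parse_concatenated_json(chunks: Iterator[bytes]) -> Iterator[dict]:
    accumulated = ""
    for chunk in chunks:
        accumulated += chunk.decode("utf-8")
        try:
            yield json.loads(accumulated)
            accumulated = ""  # reset after a successful parse
        except json.JSONDecodeError:
            continue  # payload still incomplete; wait for more bytes


print(list(parse_concatenated_json(iter([b'{"token": {"te', b'xt": "hi"}}']))))
# -> [{'token': {'text': 'hi'}}]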


@@ -1,7 +1,12 @@
 model_list:
-  - model_name: "*"
+  - model_name: fake-openai-endpoint
     litellm_params:
-      model: "*"
+      model: sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614
+      # sagemaker_base_url: https://exampleopenaiendpoint-production.up.railway.app/invocations/
+      # api_base: https://exampleopenaiendpoint-production.up.railway.app
+
+general_settings:
+  global_max_parallel_requests: 0