forked from phoenix/litellm-mirror
perf(sagemaker.py): asyncify hf prompt template check
leads to a 189% improvement in RPS @ 100 users
This commit is contained in:
parent b0f01e5b95
commit 2cf149fbad
4 changed files with 253 additions and 125 deletions
litellm/litellm_core_utils/asyncify.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import functools
from typing import Awaitable, Callable, ParamSpec, TypeVar

import anyio
from anyio import to_thread

T_ParamSpec = ParamSpec("T_ParamSpec")
T_Retval = TypeVar("T_Retval")


def function_has_argument(function: Callable, arg_name: str) -> bool:
    """Helper function to check if a function has a specific argument."""
    import inspect

    signature = inspect.signature(function)
    return arg_name in signature.parameters


def asyncify(
    function: Callable[T_ParamSpec, T_Retval],
    *,
    cancellable: bool = False,
    limiter: anyio.CapacityLimiter | None = None,
) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
    """
    Take a blocking function and create an async one that receives the same
    positional and keyword arguments, and that when called, calls the original function
    in a worker thread using `anyio.to_thread.run_sync()`.

    If the `cancellable` option is enabled and the task waiting for its completion is
    cancelled, the thread will still run its course but its return value (or any raised
    exception) will be ignored.

    ## Arguments
    - `function`: a blocking regular callable (e.g. a function)
    - `cancellable`: `True` to allow cancellation of the operation
    - `limiter`: capacity limiter to use to limit the total amount of threads running
        (if omitted, the default limiter is used)

    ## Return
    An async function that takes the same positional and keyword arguments as the
    original one, that when called runs the same original function in a thread worker
    and returns the result.
    """

    async def wrapper(
        *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
    ) -> T_Retval:
        partial_f = functools.partial(function, *args, **kwargs)

        # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
        # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
        # surfacing deprecation warnings.
        if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
            return await anyio.to_thread.run_sync(
                partial_f,
                abandon_on_cancel=cancellable,
                limiter=limiter,
            )

        return await anyio.to_thread.run_sync(
            partial_f,
            cancellable=cancellable,
            limiter=limiter,
        )

    return wrapper
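For context, a minimal usage sketch of the asyncify helper added above. The blocking_template_check function is a hypothetical stand-in for the blocking HF prompt-template check this commit moves off the event loop; it is not part of the commit.

import asyncio
import time

from litellm.litellm_core_utils.asyncify import asyncify


def blocking_template_check(model: str) -> str:
    # hypothetical stand-in for a blocking lookup (e.g. an HTTP call)
    time.sleep(1)
    return f"prompt template for {model}"


async def main() -> None:
    # wrap the blocking callable so it runs in anyio's worker-thread pool
    async_check = asyncify(blocking_template_check)

    # both checks run in worker threads; the event loop stays responsive
    results = await asyncio.gather(
        async_check("meta-llama/Llama-2-7b-chat-hf"),
        async_check("meta-llama/Llama-2-7b"),
    )
    print(results)


asyncio.run(main())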
@@ -400,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
        tokenizer_config = known_tokenizer_config[model]
    else:
        tokenizer_config = _get_tokenizer_config(model)
        known_tokenizer_config.update({model: tokenizer_config})

    if (
        tokenizer_config["status"] == "failure"
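For illustration, a minimal sketch of the in-memory caching pattern shown in the hunk above, so repeated requests for the same model reuse the earlier result. Names are placeholders; fetch_tokenizer_config stands in for the blocking tokenizer-config lookup.

from typing import Any, Dict

# module-level cache: repeated calls for the same model skip the lookup
known_tokenizer_config: Dict[str, Dict[str, Any]] = {}


def fetch_tokenizer_config(model: str) -> Dict[str, Any]:
    # hypothetical stand-in for the HTTP request that fetches the tokenizer config
    return {"status": "success", "tokenizer": {"chat_template": "..."}}


def get_tokenizer_config_cached(model: str) -> Dict[str, Any]:
    if model in known_tokenizer_config:
        return known_tokenizer_config[model]
    tokenizer_config = fetch_tokenizer_config(model)
    known_tokenizer_config[model] = tokenizer_config
    return tokenizer_config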
@@ -15,6 +15,7 @@ import requests  # type: ignore

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
@@ -24,11 +25,8 @@ from litellm.llms.custom_httpx.http_handler import (
from litellm.types.llms.openai import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    OpenAIChatCompletionChunk,
)
from litellm.types.utils import CustomStreamingDecoder
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk
from litellm.utils import (
    CustomStreamWrapper,
    EmbeddingResponse,
@@ -37,8 +35,8 @@ from litellm.utils import (
    get_secret,
)

from ..base_aws_llm import BaseAWSLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
from .base_aws_llm import BaseAWSLLM
from .prompt_templates.factory import custom_prompt, prompt_factory

_response_stream_shape_cache = None

@@ -201,6 +199,49 @@ class SagemakerLLM(BaseAWSLLM):

        return prepped_request

    def _transform_prompt(
        self,
        model: str,
        messages: List,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
    ) -> str:
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        elif hf_model_name in custom_prompt_dict:
            # check if the base huggingface model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[hf_model_name]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        else:
            if hf_model_name is None:
                if "llama-2" in model.lower():  # llama-2 model
                    if "chat" in model.lower():  # apply llama2 chat template
                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
                    else:  # apply regular llama2 template
                        hf_model_name = "meta-llama/Llama-2-7b"
            hf_model_name = (
                hf_model_name or model
            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
            prompt = prompt_factory(model=hf_model_name, messages=messages)

        return prompt

    def completion(
        self,
        model: str,
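For reference, a hedged sketch of the kind of custom_prompt_dict entry that the _transform_prompt helper above would pick up. The model name and role/prompt values here are illustrative placeholders, not taken from the commit.

# illustrative entry; the keys mirror what _transform_prompt reads
custom_prompt_dict = {
    "my-sagemaker-endpoint": {  # hypothetical model name
        "roles": {
            "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n"},
            "user": {"pre_message": "", "post_message": " [/INST]\n"},
            "assistant": {"pre_message": "", "post_message": "\n"},
        },
        "initial_prompt_value": "<s>",
        "final_prompt_value": "",
    }
}

# with an entry like this present, _transform_prompt() routes through custom_prompt()
# instead of falling back to prompt_factory() and the HF chat-template lookup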
@@ -244,10 +285,6 @@ class SagemakerLLM(BaseAWSLLM):
                aws_region_name=aws_region_name,
            )

            custom_stream_decoder = AWSEventStreamDecoder(
                model="", is_messages_api=True
            )

            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
@@ -266,7 +303,6 @@ class SagemakerLLM(BaseAWSLLM):
                headers=prepared_request.headers,
                custom_endpoint=True,
                custom_llm_provider="sagemaker_chat",
                streaming_decoder=custom_stream_decoder,  # type: ignore
            )

        ## Load Config
@@ -277,42 +313,8 @@ class SagemakerLLM(BaseAWSLLM):
            ):  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        elif hf_model_name in custom_prompt_dict:
            # check if the base huggingface model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[hf_model_name]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        else:
            if hf_model_name is None:
                if "llama-2" in model.lower():  # llama-2 model
                    if "chat" in model.lower():  # apply llama2 chat template
                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
                    else:  # apply regular llama2 template
                        hf_model_name = "meta-llama/Llama-2-7b"
            hf_model_name = (
                hf_model_name or model
            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
            prompt = prompt_factory(model=hf_model_name, messages=messages)

        if stream is True:
            data = {"inputs": prompt, "parameters": inference_params, "stream": True}
            data = {"parameters": inference_params, "stream": True}
            prepared_request = self._prepare_request(
                model=model,
                data=data,
@@ -329,18 +331,41 @@ class SagemakerLLM(BaseAWSLLM):

            if acompletion is True:
                response = self.async_streaming(
                    prepared_request=prepared_request,
                    messages=messages,
                    model=model,
                    custom_prompt_dict=custom_prompt_dict,
                    hf_model_name=hf_model_name,
                    optional_params=optional_params,
                    encoding=encoding,
                    model_response=model_response,
                    model=model,
                    logging_obj=logging_obj,
                    data=data,
                    model_id=model_id,
                    aws_region_name=aws_region_name,
                    credentials=credentials,
                )
                return response
            else:
                if stream is not None and stream is True:
                prompt = self._transform_prompt(
                    model=model,
                    messages=messages,
                    custom_prompt_dict=custom_prompt_dict,
                    hf_model_name=hf_model_name,
                )
                data["inputs"] = prompt
                prepared_request = self._prepare_request(
                    model=model,
                    data=data,
                    optional_params=optional_params,
                    credentials=credentials,
                    aws_region_name=aws_region_name,
                )
                if model_id is not None:
                    # Add model_id as InferenceComponentName header
                    # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
                    prepared_request.headers.update(
                        {"X-Amzn-SageMaker-Inference-Component": model_id}
                    )
                sync_handler = _get_httpx_client()
                sync_response = sync_handler.post(
                    url=prepared_request.url,
@@ -377,27 +402,41 @@ class SagemakerLLM(BaseAWSLLM):
                return streaming_response

        # Non-Streaming Requests
        _data = {"inputs": prompt, "parameters": inference_params}
        prepared_request = self._prepare_request(
        _data = {"parameters": inference_params}
        prepared_request_args = {
            "model": model,
            "data": _data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
        }

        # Async completion
        if acompletion is True:
            return self.async_completion(
                messages=messages,
                model=model,
                custom_prompt_dict=custom_prompt_dict,
                hf_model_name=hf_model_name,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                data=_data,
                model_id=model_id,
                optional_params=optional_params,
                credentials=credentials,
                aws_region_name=aws_region_name,
            )

        # Async completion
        if acompletion is True:
            return self.async_completion(
                prepared_request=prepared_request,
                model_response=model_response,
                encoding=encoding,
        prompt = self._transform_prompt(
            model=model,
                logging_obj=logging_obj,
                data=_data,
                model_id=model_id,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        _data["inputs"] = prompt
        ## Non-Streaming completion CALL
        prepared_request = self._prepare_request(**prepared_request_args)
        try:
            if model_id is not None:
                # Add model_id as InferenceComponentName header
@@ -483,7 +522,7 @@ class SagemakerLLM(BaseAWSLLM):
                    completion_output = completion_output.replace(prompt, "", 1)

                model_response.choices[0].message.content = completion_output  # type: ignore
        except:
        except Exception:
            raise SagemakerError(
                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                status_code=500,
@@ -555,15 +594,34 @@ class SagemakerLLM(BaseAWSLLM):

    async def async_streaming(
        self,
        prepared_request,
        messages: list,
        model: str,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
        credentials,
        aws_region_name: str,
        optional_params,
        encoding,
        model_response: ModelResponse,
        model: str,
        model_id: Optional[str],
        logging_obj: Any,
        data,
    ):
        data["inputs"] = self._transform_prompt(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        asyncified_prepare_request = asyncify(self._prepare_request)
        prepared_request_args = {
            "model": model,
            "data": data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
        }
        prepared_request = await asyncified_prepare_request(**prepared_request_args)
        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
@@ -590,16 +648,40 @@ class SagemakerLLM(BaseAWSLLM):

    async def async_completion(
        self,
        prepared_request,
        messages: list,
        model: str,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
        credentials,
        aws_region_name: str,
        encoding,
        model_response: ModelResponse,
        model: str,
        optional_params: dict,
        logging_obj: Any,
        data: dict,
        model_id: Optional[str],
    ):
        timeout = 300.0
        async_handler = _get_async_httpx_client()

        async_transform_prompt = asyncify(self._transform_prompt)

        data["inputs"] = await async_transform_prompt(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        asyncified_prepare_request = asyncify(self._prepare_request)
        prepared_request_args = {
            "model": model,
            "data": data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
        }

        prepared_request = await asyncified_prepare_request(**prepared_request_args)
        ## LOGGING
        logging_obj.pre_call(
            input=[],
@@ -669,7 +751,7 @@ class SagemakerLLM(BaseAWSLLM):
                    completion_output = completion_output.replace(data["inputs"], "", 1)

                model_response.choices[0].message.content = completion_output  # type: ignore
        except:
        except Exception:
            raise SagemakerError(
                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                status_code=500,
@@ -855,21 +937,12 @@ def get_response_stream_shape():


class AWSEventStreamDecoder:
    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
    def __init__(self, model: str) -> None:
        from botocore.parsers import EventStreamJSONParser

        self.model = model
        self.parser = EventStreamJSONParser()
        self.content_blocks: List = []
        self.is_messages_api = is_messages_api

    def _chunk_parser_messages_api(
        self, chunk_data: dict
    ) -> StreamingChatCompletionChunk:

        openai_chunk = StreamingChatCompletionChunk(**chunk_data)

        return openai_chunk

    def _chunk_parser(self, chunk_data: dict) -> GChunk:
        verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -885,7 +958,6 @@ class AWSEventStreamDecoder:
                index=_index,
                is_finished=True,
                finish_reason="stop",
                usage=None,
            )

        return GChunk(
@@ -893,12 +965,9 @@ class AWSEventStreamDecoder:
            index=_index,
            is_finished=is_finished,
            finish_reason=finish_reason,
            usage=None,
        )

    def iter_bytes(
        self, iterator: Iterator[bytes]
    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

@@ -919,9 +988,6 @@ class AWSEventStreamDecoder:
                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                            yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
@@ -933,20 +999,16 @@ class AWSEventStreamDecoder:
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                    yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError as e:
            except json.JSONDecodeError:
                # Handle or log any unparseable data at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None

    async def aiter_bytes(
        self, iterator: AsyncIterator[bytes]
    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
    ) -> AsyncIterator[GChunk]:
        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

@@ -968,9 +1030,6 @@ class AWSEventStreamDecoder:
                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                            yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
@@ -982,16 +1041,12 @@ class AWSEventStreamDecoder:
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                    yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError:
                # Handle or log any unparseable data at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None

    def _parse_message_from_event(self, event) -> Optional[str]:
        response_dict = event.to_response_dict()
@@ -1,7 +1,12 @@
model_list:
  - model_name: "*"
  - model_name: fake-openai-endpoint
    litellm_params:
      model: "*"
      model: sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614
      # sagemaker_base_url: https://exampleopenaiendpoint-production.up.railway.app/invocations/
      # api_base: https://exampleopenaiendpoint-production.up.railway.app

general_settings:
  global_max_parallel_requests: 0
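As a usage sketch for the proxy config above: once the proxy is running, the registered model_name can be called through any OpenAI-compatible client. The address and API key below are placeholder assumptions (the litellm proxy listens on port 4000 by default), not values from the commit.

import openai

# point the OpenAI SDK at the litellm proxy defined by the config above
client = openai.OpenAI(
    api_key="sk-1234",               # placeholder key
    base_url="http://0.0.0.0:4000",  # assumed local proxy address
)

response = client.chat.completions.create(
    model="fake-openai-endpoint",    # model_name registered in the config
    messages=[{"role": "user", "content": "Hello from SageMaker via litellm"}],
)
print(response.choices[0].message.content)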