From 2cf149fbad34a2c30812cc46aaa98c236252cb69 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 23 Aug 2024 15:45:42 -0700
Subject: [PATCH] perf(sagemaker.py): asyncify hf prompt template check

leads to 189% improvement in RPS @ 100 users
---
 litellm/litellm_core_utils/asyncify.py   |  67 +++++
 litellm/llms/prompt_templates/factory.py |   1 +
 litellm/llms/sagemaker/sagemaker.py      | 297 ++++++++++++++---------
 litellm/proxy/_new_secret_config.yaml    |  13 +-
 4 files changed, 253 insertions(+), 125 deletions(-)
 create mode 100644 litellm/litellm_core_utils/asyncify.py

diff --git a/litellm/litellm_core_utils/asyncify.py b/litellm/litellm_core_utils/asyncify.py
new file mode 100644
index 000000000..a9ce62a8b
--- /dev/null
+++ b/litellm/litellm_core_utils/asyncify.py
@@ -0,0 +1,67 @@
+import functools
+from typing import Awaitable, Callable, ParamSpec, TypeVar
+
+import anyio
+from anyio import to_thread
+
+T_ParamSpec = ParamSpec("T_ParamSpec")
+T_Retval = TypeVar("T_Retval")
+
+
+def function_has_argument(function: Callable, arg_name: str) -> bool:
+    """Helper function to check if a function has a specific argument."""
+    import inspect
+
+    signature = inspect.signature(function)
+    return arg_name in signature.parameters
+
+
+def asyncify(
+    function: Callable[T_ParamSpec, T_Retval],
+    *,
+    cancellable: bool = False,
+    limiter: anyio.CapacityLimiter | None = None,
+) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
+    """
+    Take a blocking function and create an async one that receives the same
+    positional and keyword arguments, and that when called, calls the original function
+    in a worker thread using `anyio.to_thread.run_sync()`.
+
+    If the `cancellable` option is enabled and the task waiting for its completion is
+    cancelled, the thread will still run its course but its return value (or any raised
+    exception) will be ignored.
+
+    ## Arguments
+    - `function`: a blocking regular callable (e.g. a function)
+    - `cancellable`: `True` to allow cancellation of the operation
+    - `limiter`: capacity limiter to use to limit the total amount of threads running
+      (if omitted, the default limiter is used)
+
+    ## Return
+    An async function that takes the same positional and keyword arguments as the
+    original one, that when called runs the same original function in a thread worker
+    and returns the result.
+    """
+
+    async def wrapper(
+        *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
+    ) -> T_Retval:
+        partial_f = functools.partial(function, *args, **kwargs)
+
+        # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
+        # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
+        # surfacing deprecation warnings.
+        if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
+            return await anyio.to_thread.run_sync(
+                partial_f,
+                abandon_on_cancel=cancellable,
+                limiter=limiter,
+            )
+
+        return await anyio.to_thread.run_sync(
+            partial_f,
+            cancellable=cancellable,
+            limiter=limiter,
+        )
+
+    return wrapper
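Note: `asyncify` returns an awaitable wrapper around any blocking callable. A minimal usage sketch, grounded in the file above (the `slow_template_fetch` function is a made-up stand-in, not part of this patch):

```python
import asyncio

from litellm.litellm_core_utils.asyncify import asyncify


def slow_template_fetch(model: str) -> str:
    # stand-in for a blocking call, e.g. pulling a HF tokenizer config over HTTP
    return f"template-for-{model}"


async def main() -> None:
    # the wrapped callable runs in anyio's worker-thread pool,
    # so the event loop stays free while it executes
    template = await asyncify(slow_template_fetch)("meta-llama/Llama-2-7b-chat-hf")
    print(template)


asyncio.run(main())
```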
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 8dbab29ad..2396cd26c 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -400,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
         tokenizer_config = known_tokenizer_config[model]
     else:
         tokenizer_config = _get_tokenizer_config(model)
+        known_tokenizer_config.update({model: tokenizer_config})
 
     if (
         tokenizer_config["status"] == "failure"
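This one-line factory change is the hot-path fix: `hf_chat_template` now writes fetched tokenizer configs back into `known_tokenizer_config`, turning it into a read-through cache, so the network fetch happens once per model instead of on every request. The pattern in isolation (names here are illustrative, not litellm internals):

```python
from typing import Callable, Dict

_tokenizer_config_cache: Dict[str, dict] = {}


def get_tokenizer_config_cached(model: str, fetch: Callable[[str], dict]) -> dict:
    # read-through cache: the first call pays the network cost,
    # every later call for the same model is a dict lookup
    if model not in _tokenizer_config_cache:
        _tokenizer_config_cache[model] = fetch(model)
    return _tokenizer_config_cache[model]
```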
diff --git a/litellm/llms/sagemaker/sagemaker.py b/litellm/llms/sagemaker/sagemaker.py
index 32f73f7ee..c83b80bcb 100644
--- a/litellm/llms/sagemaker/sagemaker.py
+++ b/litellm/llms/sagemaker/sagemaker.py
@@ -15,6 +15,7 @@ import requests  # type: ignore
 
 import litellm
 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.asyncify import asyncify
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -24,11 +25,8 @@ from litellm.llms.custom_httpx.http_handler import (
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
-    OpenAIChatCompletionChunk,
 )
-from litellm.types.utils import CustomStreamingDecoder
 from litellm.types.utils import GenericStreamingChunk as GChunk
-from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -37,8 +35,8 @@ from litellm.utils import (
     get_secret,
 )
 
-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory
 
 _response_stream_shape_cache = None
@@ -201,6 +199,49 @@ class SagemakerLLM(BaseAWSLLM):
 
         return prepped_request
 
+    def _transform_prompt(
+        self,
+        model: str,
+        messages: List,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+    ) -> str:
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        elif hf_model_name in custom_prompt_dict:
+            # check if the base huggingface model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[hf_model_name]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        else:
+            if hf_model_name is None:
+                if "llama-2" in model.lower():  # llama-2 model
+                    if "chat" in model.lower():  # apply llama2 chat template
+                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
+                    else:  # apply regular llama2 template
+                        hf_model_name = "meta-llama/Llama-2-7b"
+            hf_model_name = (
+                hf_model_name or model
+            )  # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
+            prompt = prompt_factory(model=hf_model_name, messages=messages)
+
+        return prompt
+
     def completion(
         self,
         model: str,
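`_transform_prompt` consults `custom_prompt_dict` before falling back to `prompt_factory`, so per-endpoint templates keep working. For reference, litellm's documented way to register one (endpoint name and role markers below are illustrative):

```python
import litellm

# keys here mirror what _transform_prompt reads back out of custom_prompt_dict
litellm.register_prompt_template(
    model="my-sagemaker-llama2-endpoint",  # hypothetical endpoint name
    initial_prompt_value="<s>",
    roles={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
        "assistant": {"pre_message": " ", "post_message": "</s>"},
    },
    final_prompt_value="",
)
```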
@@ -244,10 +285,6 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )
 
-            custom_stream_decoder = AWSEventStreamDecoder(
-                model="", is_messages_api=True
-            )
-
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -266,7 +303,6 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
-                streaming_decoder=custom_stream_decoder,  # type: ignore
             )
 
         ## Load Config
@@ -277,42 +313,8 @@ class SagemakerLLM(BaseAWSLLM):
         ):  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
 
-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", None),
-                initial_prompt_value=model_prompt_details.get(
-                    "initial_prompt_value", ""
-                ),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                messages=messages,
-            )
-        elif hf_model_name in custom_prompt_dict:
-            # check if the base huggingface model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[hf_model_name]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", None),
-                initial_prompt_value=model_prompt_details.get(
-                    "initial_prompt_value", ""
-                ),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                messages=messages,
-            )
-        else:
-            if hf_model_name is None:
-                if "llama-2" in model.lower():  # llama-2 model
-                    if "chat" in model.lower():  # apply llama2 chat template
-                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
-                    else:  # apply regular llama2 template
-                        hf_model_name = "meta-llama/Llama-2-7b"
-            hf_model_name = (
-                hf_model_name or model
-            )  # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
-            prompt = prompt_factory(model=hf_model_name, messages=messages)
-
         if stream is True:
-            data = {"inputs": prompt, "parameters": inference_params, "stream": True}
+            data = {"parameters": inference_params, "stream": True}
             prepared_request = self._prepare_request(
                 model=model,
                 data=data,
                 optional_params=optional_params,
                 credentials=credentials,
                 aws_region_name=aws_region_name,
             )
@@ -329,43 +331,66 @@ class SagemakerLLM(BaseAWSLLM):
 
             if acompletion is True:
                 response = self.async_streaming(
-                    prepared_request=prepared_request,
+                    messages=messages,
+                    model=model,
+                    custom_prompt_dict=custom_prompt_dict,
+                    hf_model_name=hf_model_name,
                     optional_params=optional_params,
                     encoding=encoding,
                     model_response=model_response,
-                    model=model,
                     logging_obj=logging_obj,
                     data=data,
                     model_id=model_id,
+                    aws_region_name=aws_region_name,
+                    credentials=credentials,
                 )
                 return response
             else:
-                if stream is not None and stream is True:
-                    sync_handler = _get_httpx_client()
-                    sync_response = sync_handler.post(
-                        url=prepared_request.url,
-                        headers=prepared_request.headers,  # type: ignore
-                        json=data,
-                        stream=stream,
-                    )
+                prompt = self._transform_prompt(
+                    model=model,
+                    messages=messages,
+                    custom_prompt_dict=custom_prompt_dict,
+                    hf_model_name=hf_model_name,
+                )
+                data["inputs"] = prompt
+                prepared_request = self._prepare_request(
+                    model=model,
+                    data=data,
+                    optional_params=optional_params,
+                    credentials=credentials,
+                    aws_region_name=aws_region_name,
+                )
+                if model_id is not None:
+                    # Add model_id as InferenceComponentName header
+                    # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
+                    prepared_request.headers.update(
+                        {"X-Amzn-SageMaker-Inference-Component": model_id}
+                    )
+                sync_handler = _get_httpx_client()
+                sync_response = sync_handler.post(
+                    url=prepared_request.url,
+                    headers=prepared_request.headers,  # type: ignore
+                    json=data,
+                    stream=stream,
+                )
 
-                if sync_response.status_code != 200:
-                    raise SagemakerError(
-                        status_code=sync_response.status_code,
-                        message=sync_response.read(),
-                    )
+                if sync_response.status_code != 200:
+                    raise SagemakerError(
+                        status_code=sync_response.status_code,
+                        message=sync_response.read(),
+                    )
 
-                decoder = AWSEventStreamDecoder(model="")
+                decoder = AWSEventStreamDecoder(model="")
 
-                completion_stream = decoder.iter_bytes(
-                    sync_response.iter_bytes(chunk_size=1024)
-                )
-                streaming_response = CustomStreamWrapper(
-                    completion_stream=completion_stream,
-                    model=model,
-                    custom_llm_provider="sagemaker",
-                    logging_obj=logging_obj,
-                )
+                completion_stream = decoder.iter_bytes(
+                    sync_response.iter_bytes(chunk_size=1024)
+                )
+                streaming_response = CustomStreamWrapper(
+                    completion_stream=completion_stream,
+                    model=model,
+                    custom_llm_provider="sagemaker",
+                    logging_obj=logging_obj,
+                )
 
             ## LOGGING
             logging_obj.post_call(
@@ -377,27 +402,41 @@ class SagemakerLLM(BaseAWSLLM):
 
             return streaming_response
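The sync streaming path above now builds the prompt and signs the request inside the `else` branch before posting with `stream=True`; caller-side nothing changes. For example (endpoint name copied from the config at the end of this patch):

```python
import litellm

response = litellm.completion(
    model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    # each chunk is decoded from the AWS event stream by AWSEventStreamDecoder
    print(chunk.choices[0].delta.content or "", end="")
```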
 
         # Non-Streaming Requests
-        _data = {"inputs": prompt, "parameters": inference_params}
-        prepared_request = self._prepare_request(
-            model=model,
-            data=_data,
-            optional_params=optional_params,
-            credentials=credentials,
-            aws_region_name=aws_region_name,
-        )
+        _data = {"parameters": inference_params}
+        prepared_request_args = {
+            "model": model,
+            "data": _data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
 
         # Async completion
         if acompletion is True:
             return self.async_completion(
-                prepared_request=prepared_request,
+                messages=messages,
+                model=model,
+                custom_prompt_dict=custom_prompt_dict,
+                hf_model_name=hf_model_name,
                 model_response=model_response,
                 encoding=encoding,
-                model=model,
                 logging_obj=logging_obj,
                 data=_data,
                 model_id=model_id,
+                optional_params=optional_params,
+                credentials=credentials,
+                aws_region_name=aws_region_name,
             )
+
+        prompt = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        _data["inputs"] = prompt
         ## Non-Streaming completion CALL
+        prepared_request = self._prepare_request(**prepared_request_args)
         try:
             if model_id is not None:
                 # Add model_id as InferenceComponentName header
@@ -483,7 +522,7 @@ class SagemakerLLM(BaseAWSLLM):
                 completion_output = completion_output.replace(prompt, "", 1)
 
             model_response.choices[0].message.content = completion_output  # type: ignore
-        except:
+        except Exception:
             raise SagemakerError(
                 message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                 status_code=500,
@@ -555,15 +594,34 @@ class SagemakerLLM(BaseAWSLLM):
 
     async def async_streaming(
         self,
-        prepared_request,
+        messages: list,
+        model: str,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+        credentials,
+        aws_region_name: str,
         optional_params,
         encoding,
         model_response: ModelResponse,
-        model: str,
         model_id: Optional[str],
         logging_obj: Any,
         data,
     ):
+        data["inputs"] = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        asyncified_prepare_request = asyncify(self._prepare_request)
+        prepared_request_args = {
+            "model": model,
+            "data": data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
+        prepared_request = await asyncified_prepare_request(**prepared_request_args)
         streaming_response = CustomStreamWrapper(
             completion_stream=None,
             make_call=partial(
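Both async entry points now take raw `messages` and push the two blocking steps, `_transform_prompt` (may fetch a HF chat template) and `_prepare_request` (botocore SigV4 signing), through `asyncify`, which is where the RPS win comes from. From the caller's side nothing changes, e.g.:

```python
import asyncio

import litellm


async def main() -> None:
    # prompt templating + request signing now run in worker threads
    # instead of blocking the event loop
    resp = await litellm.acompletion(
        model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```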
@@ -590,16 +648,40 @@ class SagemakerLLM(BaseAWSLLM):
 
     async def async_completion(
         self,
-        prepared_request,
+        messages: list,
+        model: str,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+        credentials,
+        aws_region_name: str,
         encoding,
         model_response: ModelResponse,
-        model: str,
+        optional_params: dict,
         logging_obj: Any,
         data: dict,
         model_id: Optional[str],
     ):
         timeout = 300.0
         async_handler = _get_async_httpx_client()
+
+        async_transform_prompt = asyncify(self._transform_prompt)
+
+        data["inputs"] = await async_transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        asyncified_prepare_request = asyncify(self._prepare_request)
+        prepared_request_args = {
+            "model": model,
+            "data": data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
+
+        prepared_request = await asyncified_prepare_request(**prepared_request_args)
         ## LOGGING
         logging_obj.pre_call(
             input=[],
@@ -669,7 +751,7 @@ class SagemakerLLM(BaseAWSLLM):
                 completion_output = completion_output.replace(data["inputs"], "", 1)
 
             model_response.choices[0].message.content = completion_output  # type: ignore
-        except:
+        except Exception:
             raise SagemakerError(
                 message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                 status_code=500,
@@ -855,21 +937,12 @@ def get_response_stream_shape():
 
 
 class AWSEventStreamDecoder:
-    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
+    def __init__(self, model: str) -> None:
         from botocore.parsers import EventStreamJSONParser
 
         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
-        self.is_messages_api = is_messages_api
-
-    def _chunk_parser_messages_api(
-        self, chunk_data: dict
-    ) -> StreamingChatCompletionChunk:
-
-        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
-
-        return openai_chunk
 
     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -885,7 +958,6 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
-                usage=None,
             )
 
         return GChunk(
@@ -893,12 +965,9 @@ class AWSEventStreamDecoder:
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
-            usage=None,
         )
 
-    def iter_bytes(
-        self, iterator: Iterator[bytes]
-    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -919,10 +988,7 @@ class AWSEventStreamDecoder:
                         # Try to parse the accumulated JSON
                         try:
                             _data = json.loads(accumulated_json)
-                            if self.is_messages_api:
-                                yield self._chunk_parser_messages_api(chunk_data=_data)
-                            else:
-                                yield self._chunk_parser(chunk_data=_data)
+                            yield self._chunk_parser(chunk_data=_data)
                             # Reset accumulated_json after successful parsing
                             accumulated_json = ""
                         except json.JSONDecodeError:
@@ -933,20 +999,16 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                if self.is_messages_api:
-                    yield self._chunk_parser_messages_api(chunk_data=_data)
-                else:
-                    yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError as e:
+                yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
-                yield None
 
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+    ) -> AsyncIterator[GChunk]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -968,10 +1030,7 @@ class AWSEventStreamDecoder:
                         # Try to parse the accumulated JSON
                         try:
                             _data = json.loads(accumulated_json)
-                            if self.is_messages_api:
-                                yield self._chunk_parser_messages_api(chunk_data=_data)
-                            else:
-                                yield self._chunk_parser(chunk_data=_data)
+                            yield self._chunk_parser(chunk_data=_data)
                             # Reset accumulated_json after successful parsing
                             accumulated_json = ""
                         except json.JSONDecodeError:
@@ -982,16 +1041,12 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                if self.is_messages_api:
-                    yield self._chunk_parser_messages_api(chunk_data=_data)
-                else:
-                    yield self._chunk_parser(chunk_data=_data)
+                yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
-                yield None
 
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
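With the messages-API branch removed, the decoder is back to a single job: SageMaker may split one JSON payload across several event-stream frames, so both `iter_bytes` and `aiter_bytes` buffer text until `json.loads` succeeds, then reset the buffer. The core pattern in isolation (a sketch, not the litellm class):

```python
import json
from typing import Iterator


def parse_buffered_json(frames: Iterator[str]) -> Iterator[dict]:
    buffered = ""
    for frame in frames:
        buffered += frame
        try:
            yield json.loads(buffered)
            buffered = ""  # complete object parsed; start accumulating fresh
        except json.JSONDecodeError:
            continue  # payload still incomplete; keep buffering


# two payloads, the first split across frames
print(list(parse_buffered_json(iter(['{"token": ', '"hi"}', '{"done": true}']))))
```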
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index f049bfeb3..2bc2a8590 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,7 +1,12 @@
 model_list:
-  - model_name: "*"
+  - model_name: fake-openai-endpoint
     litellm_params:
-      model: "*"
+      model: sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614
+      # sagemaker_base_url: https://exampleopenaiendpoint-production.up.railway.app/invocations/
+      # api_base: https://exampleopenaiendpoint-production.up.railway.app
+
+
+
+
+
-general_settings:
-  global_max_parallel_requests: 0
\ No newline at end of file
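The config now pins `fake-openai-endpoint` to a real JumpStart endpoint, which is the setup the 100-user RPS benchmark ran against. Assuming the proxy is started with this file (`litellm --config litellm/proxy/_new_secret_config.yaml`) on the default port, it can be exercised with any OpenAI-compatible client, e.g.:

```python
import openai

# virtual key and port follow the litellm proxy quick-start defaults
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

resp = client.chat.completions.create(
    model="fake-openai-endpoint",  # model_name from the config above
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.choices[0].message.content)
```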