Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
perf(sagemaker.py): asyncify hf prompt template check

Leads to a 189% improvement in RPS @ 100 users.
Parent: b0f01e5b95
Commit: 2cf149fbad
4 changed files with 253 additions and 125 deletions
litellm/litellm_core_utils/asyncify.py (new file, 67 lines)
import functools
from typing import Awaitable, Callable, ParamSpec, TypeVar

import anyio
from anyio import to_thread

T_ParamSpec = ParamSpec("T_ParamSpec")
T_Retval = TypeVar("T_Retval")


def function_has_argument(function: Callable, arg_name: str) -> bool:
    """Helper function to check if a function has a specific argument."""
    import inspect

    signature = inspect.signature(function)
    return arg_name in signature.parameters


def asyncify(
    function: Callable[T_ParamSpec, T_Retval],
    *,
    cancellable: bool = False,
    limiter: anyio.CapacityLimiter | None = None,
) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
    """
    Take a blocking function and create an async one that receives the same
    positional and keyword arguments, and that when called, calls the original function
    in a worker thread using `anyio.to_thread.run_sync()`.

    If the `cancellable` option is enabled and the task waiting for its completion is
    cancelled, the thread will still run its course but its return value (or any raised
    exception) will be ignored.

    ## Arguments
    - `function`: a blocking regular callable (e.g. a function)
    - `cancellable`: `True` to allow cancellation of the operation
    - `limiter`: capacity limiter to use to limit the total amount of threads running
      (if omitted, the default limiter is used)

    ## Return
    An async function that takes the same positional and keyword arguments as the
    original one, that when called runs the same original function in a thread worker
    and returns the result.
    """

    async def wrapper(
        *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
    ) -> T_Retval:
        partial_f = functools.partial(function, *args, **kwargs)

        # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
        # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
        # surfacing deprecation warnings.
        if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
            return await anyio.to_thread.run_sync(
                partial_f,
                abandon_on_cancel=cancellable,
                limiter=limiter,
            )

        return await anyio.to_thread.run_sync(
            partial_f,
            cancellable=cancellable,
            limiter=limiter,
        )

    return wrapper
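For context, a minimal usage sketch of the helper above, assuming an async caller; `load_hf_chat_template` is a hypothetical blocking function used only for illustration and is not part of this commit:

# Minimal usage sketch; `load_hf_chat_template` is a hypothetical stand-in
# for a blocking network lookup, not part of this commit.
import asyncio
import time

from litellm.litellm_core_utils.asyncify import asyncify


def load_hf_chat_template(model: str) -> str:
    time.sleep(0.5)  # stand-in for a blocking HTTP call
    return f"<template for {model}>"


async def main() -> None:
    # The blocking call runs in a worker thread, so the event loop stays free
    # to serve other requests while it completes.
    template = await asyncify(load_hf_chat_template)("meta-llama/Llama-2-7b-chat-hf")
    print(template)


asyncio.run(main())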
@@ -400,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
             tokenizer_config = known_tokenizer_config[model]
         else:
             tokenizer_config = _get_tokenizer_config(model)
+            known_tokenizer_config.update({model: tokenizer_config})

         if (
             tokenizer_config["status"] == "failure"
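The added `known_tokenizer_config.update(...)` line memoizes the fetched tokenizer config per model, so the blocking Hugging Face lookup runs at most once per model. A minimal sketch of the same caching pattern, with hypothetical names standing in for the real helpers:

import time

_tokenizer_config_cache: dict = {}


def _fetch_tokenizer_config(model: str) -> dict:
    time.sleep(0.2)  # stand-in for the HTTP request to the Hugging Face Hub
    return {"status": "success", "tokenizer": {"chat_template": "..."}}


def get_tokenizer_config(model: str) -> dict:
    # First call for a model pays the network cost; later calls hit the cache.
    if model not in _tokenizer_config_cache:
        _tokenizer_config_cache[model] = _fetch_tokenizer_config(model)
    return _tokenizer_config_cache[model]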
@@ -15,6 +15,7 @@ import requests  # type: ignore

 import litellm
 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.asyncify import asyncify
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -24,11 +25,8 @@ from litellm.llms.custom_httpx.http_handler import (
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
-    OpenAIChatCompletionChunk,
 )
-from litellm.types.utils import CustomStreamingDecoder
 from litellm.types.utils import GenericStreamingChunk as GChunk
-from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -37,8 +35,8 @@ from litellm.utils import (
     get_secret,
 )

-from ..base_aws_llm import BaseAWSLLM
-from ..prompt_templates.factory import custom_prompt, prompt_factory
+from .base_aws_llm import BaseAWSLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory

 _response_stream_shape_cache = None

@@ -201,6 +199,49 @@ class SagemakerLLM(BaseAWSLLM):

         return prepped_request

+    def _transform_prompt(
+        self,
+        model: str,
+        messages: List,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+    ) -> str:
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        elif hf_model_name in custom_prompt_dict:
+            # check if the base huggingface model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[hf_model_name]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        else:
+            if hf_model_name is None:
+                if "llama-2" in model.lower():  # llama-2 model
+                    if "chat" in model.lower():  # apply llama2 chat template
+                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
+                    else:  # apply regular llama2 template
+                        hf_model_name = "meta-llama/Llama-2-7b"
+            hf_model_name = (
+                hf_model_name or model
+            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
+            prompt = prompt_factory(model=hf_model_name, messages=messages)
+
+        return prompt
+
     def completion(
         self,
         model: str,
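For reference, `_transform_prompt` only reads three keys from a registered custom-prompt entry via `.get()`: `roles`, `initial_prompt_value`, and `final_prompt_value`. A hypothetical entry in `custom_prompt_dict` might look roughly like this; the model name and template strings are placeholders, not values from the commit:

# Illustrative shape of one custom_prompt_dict entry, matching the keys the
# helper above reads; the endpoint name and template strings are placeholders.
custom_prompt_dict = {
    "sagemaker/my-llama2-endpoint": {
        "roles": {
            "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
            "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
            "assistant": {"pre_message": "\n", "post_message": "\n"},
        },
        "initial_prompt_value": "<s>",
        "final_prompt_value": "</s>",
    }
}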
@@ -244,10 +285,6 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )

-            custom_stream_decoder = AWSEventStreamDecoder(
-                model="", is_messages_api=True
-            )
-
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -266,7 +303,6 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
-                streaming_decoder=custom_stream_decoder,  # type: ignore
             )

         ## Load Config
@@ -277,42 +313,8 @@ class SagemakerLLM(BaseAWSLLM):
             ):  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v

-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", None),
-                initial_prompt_value=model_prompt_details.get(
-                    "initial_prompt_value", ""
-                ),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                messages=messages,
-            )
-        elif hf_model_name in custom_prompt_dict:
-            # check if the base huggingface model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[hf_model_name]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", None),
-                initial_prompt_value=model_prompt_details.get(
-                    "initial_prompt_value", ""
-                ),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                messages=messages,
-            )
-        else:
-            if hf_model_name is None:
-                if "llama-2" in model.lower():  # llama-2 model
-                    if "chat" in model.lower():  # apply llama2 chat template
-                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
-                    else:  # apply regular llama2 template
-                        hf_model_name = "meta-llama/Llama-2-7b"
-            hf_model_name = (
-                hf_model_name or model
-            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
-            prompt = prompt_factory(model=hf_model_name, messages=messages)
-
         if stream is True:
-            data = {"inputs": prompt, "parameters": inference_params, "stream": True}
+            data = {"parameters": inference_params, "stream": True}
             prepared_request = self._prepare_request(
                 model=model,
                 data=data,
@@ -329,18 +331,41 @@ class SagemakerLLM(BaseAWSLLM):

             if acompletion is True:
                 response = self.async_streaming(
-                    prepared_request=prepared_request,
+                    messages=messages,
+                    model=model,
+                    custom_prompt_dict=custom_prompt_dict,
+                    hf_model_name=hf_model_name,
                     optional_params=optional_params,
                     encoding=encoding,
                     model_response=model_response,
-                    model=model,
                     logging_obj=logging_obj,
                     data=data,
                     model_id=model_id,
+                    aws_region_name=aws_region_name,
+                    credentials=credentials,
                 )
                 return response
             else:
-                if stream is not None and stream is True:
+                prompt = self._transform_prompt(
+                    model=model,
+                    messages=messages,
+                    custom_prompt_dict=custom_prompt_dict,
+                    hf_model_name=hf_model_name,
+                )
+                data["inputs"] = prompt
+                prepared_request = self._prepare_request(
+                    model=model,
+                    data=data,
+                    optional_params=optional_params,
+                    credentials=credentials,
+                    aws_region_name=aws_region_name,
+                )
+                if model_id is not None:
+                    # Add model_id as InferenceComponentName header
+                    # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
+                    prepared_request.headers.update(
+                        {"X-Amzn-SageMaker-Inference-Component": model_id}
+                    )
                 sync_handler = _get_httpx_client()
                 sync_response = sync_handler.post(
                     url=prepared_request.url,
@@ -377,27 +402,41 @@ class SagemakerLLM(BaseAWSLLM):
                 return streaming_response

         # Non-Streaming Requests
-        _data = {"inputs": prompt, "parameters": inference_params}
-        prepared_request = self._prepare_request(
-            model=model,
-            data=_data,
-            optional_params=optional_params,
-            credentials=credentials,
-            aws_region_name=aws_region_name,
-        )
-
-        # Async completion
-        if acompletion is True:
-            return self.async_completion(
-                prepared_request=prepared_request,
-                model_response=model_response,
-                encoding=encoding,
-                model=model,
-                logging_obj=logging_obj,
-                data=_data,
-                model_id=model_id,
-            )
+        _data = {"parameters": inference_params}
+        prepared_request_args = {
+            "model": model,
+            "data": _data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
+
+        # Async completion
+        if acompletion is True:
+            return self.async_completion(
+                messages=messages,
+                model=model,
+                custom_prompt_dict=custom_prompt_dict,
+                hf_model_name=hf_model_name,
+                model_response=model_response,
+                encoding=encoding,
+                logging_obj=logging_obj,
+                data=_data,
+                model_id=model_id,
+                optional_params=optional_params,
+                credentials=credentials,
+                aws_region_name=aws_region_name,
+            )
+
+        prompt = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        _data["inputs"] = prompt
         ## Non-Streaming completion CALL
+        prepared_request = self._prepare_request(**prepared_request_args)
         try:
             if model_id is not None:
                 # Add model_id as InferenceComponentName header
@@ -483,7 +522,7 @@ class SagemakerLLM(BaseAWSLLM):
                     completion_output = completion_output.replace(prompt, "", 1)

                 model_response.choices[0].message.content = completion_output  # type: ignore
-        except:
+        except Exception:
             raise SagemakerError(
                 message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                 status_code=500,
@@ -555,15 +594,34 @@ class SagemakerLLM(BaseAWSLLM):

     async def async_streaming(
         self,
-        prepared_request,
+        messages: list,
+        model: str,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+        credentials,
+        aws_region_name: str,
         optional_params,
         encoding,
         model_response: ModelResponse,
-        model: str,
         model_id: Optional[str],
         logging_obj: Any,
         data,
     ):
+        data["inputs"] = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        asyncified_prepare_request = asyncify(self._prepare_request)
+        prepared_request_args = {
+            "model": model,
+            "data": data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
+        prepared_request = await asyncified_prepare_request(**prepared_request_args)
         streaming_response = CustomStreamWrapper(
             completion_stream=None,
             make_call=partial(
@@ -590,16 +648,40 @@ class SagemakerLLM(BaseAWSLLM):

     async def async_completion(
         self,
-        prepared_request,
+        messages: list,
+        model: str,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+        credentials,
+        aws_region_name: str,
         encoding,
         model_response: ModelResponse,
-        model: str,
+        optional_params: dict,
         logging_obj: Any,
         data: dict,
         model_id: Optional[str],
     ):
         timeout = 300.0
         async_handler = _get_async_httpx_client()
+
+        async_transform_prompt = asyncify(self._transform_prompt)
+
+        data["inputs"] = await async_transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        asyncified_prepare_request = asyncify(self._prepare_request)
+        prepared_request_args = {
+            "model": model,
+            "data": data,
+            "optional_params": optional_params,
+            "credentials": credentials,
+            "aws_region_name": aws_region_name,
+        }
+
+        prepared_request = await asyncified_prepare_request(**prepared_request_args)
         ## LOGGING
         logging_obj.pre_call(
             input=[],
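These are the hunks the commit title refers to: the blocking prompt-template and request-preparation helpers are wrapped with `asyncify` and awaited instead of running directly on the event loop. A toy benchmark sketching why that matters under concurrency; every name and number below is illustrative and not taken from the commit:

# Toy benchmark: 20 concurrent "requests" that each spend ~100 ms in a
# blocking call. Run inline on the event loop they serialize; offloaded to
# worker threads they overlap. Helper names and timings are illustrative.
import asyncio
import time

import anyio


def blocking_template_check() -> None:
    time.sleep(0.1)  # stand-in for the synchronous Hugging Face template lookup


async def handle_request_inline() -> None:
    blocking_template_check()  # blocks the event loop for the full 100 ms


async def handle_request_offloaded() -> None:
    await anyio.to_thread.run_sync(blocking_template_check)  # loop stays free


async def timed(handler) -> float:
    start = time.perf_counter()
    await asyncio.gather(*(handler() for _ in range(20)))
    return time.perf_counter() - start


async def main() -> None:
    print(f"inline   : {await timed(handle_request_inline):.2f}s")
    print(f"offloaded: {await timed(handle_request_offloaded):.2f}s")


asyncio.run(main())

With the inline variant the 20 handlers take roughly 2 s in total, while the offloaded variant finishes in a small fraction of that; this is the effect behind the reported RPS gain at 100 concurrent users.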
@@ -669,7 +751,7 @@ class SagemakerLLM(BaseAWSLLM):
                     completion_output = completion_output.replace(data["inputs"], "", 1)

                 model_response.choices[0].message.content = completion_output  # type: ignore
-        except:
+        except Exception:
             raise SagemakerError(
                 message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                 status_code=500,
@@ -855,21 +937,12 @@ def get_response_stream_shape():


 class AWSEventStreamDecoder:
-    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
+    def __init__(self, model: str) -> None:
         from botocore.parsers import EventStreamJSONParser

         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
-        self.is_messages_api = is_messages_api
-
-    def _chunk_parser_messages_api(
-        self, chunk_data: dict
-    ) -> StreamingChatCompletionChunk:
-
-        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
-
-        return openai_chunk

     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -885,7 +958,6 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
-                usage=None,
             )

         return GChunk(
@@ -893,12 +965,9 @@ class AWSEventStreamDecoder:
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
-            usage=None,
         )

-    def iter_bytes(
-        self, iterator: Iterator[bytes]
-    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer

@@ -919,9 +988,6 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        if self.is_messages_api:
-                            yield self._chunk_parser_messages_api(chunk_data=_data)
-                        else:
-                            yield self._chunk_parser(chunk_data=_data)
+                        yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -933,20 +999,16 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                if self.is_messages_api:
-                    yield self._chunk_parser_messages_api(chunk_data=_data)
-                else:
-                    yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError as e:
+                yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
-                yield None

     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
+    ) -> AsyncIterator[GChunk]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer

@@ -968,9 +1030,6 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        if self.is_messages_api:
-                            yield self._chunk_parser_messages_api(chunk_data=_data)
-                        else:
-                            yield self._chunk_parser(chunk_data=_data)
+                        yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -982,16 +1041,12 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                if self.is_messages_api:
-                    yield self._chunk_parser_messages_api(chunk_data=_data)
-                else:
-                    yield self._chunk_parser(chunk_data=_data)
+                yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
-                yield None

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
@@ -1,7 +1,12 @@
 model_list:
-  - model_name: "*"
+  - model_name: fake-openai-endpoint
     litellm_params:
-      model: "*"
+      model: sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614
+      # sagemaker_base_url: https://exampleopenaiendpoint-production.up.railway.app/invocations/
+      # api_base: https://exampleopenaiendpoint-production.up.railway.app

-general_settings:
-  global_max_parallel_requests: 0
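The config above remaps the test model name onto a SageMaker endpoint for load testing. Assuming the proxy is started locally against this file (for example `litellm --config config.yaml`, which serves an OpenAI-compatible API on port 4000 by default), it could be exercised like this; the base URL and API key are placeholders for such a local setup, not values from the commit:

# Assumes a locally running LiteLLM proxy loaded with the config above;
# base_url and api_key are placeholders for a local load-test setup.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="fake-openai-endpoint",  # routed by the proxy to the SageMaker endpoint
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)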