Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

fix(sagemaker.py): support streaming for messages api

Fixes https://github.com/BerriAI/litellm/issues/5372

Parent: 174b1c43e3
Commit: 8e9acd117b

8 changed files with 142 additions and 32 deletions
@@ -856,7 +856,7 @@ from .llms.vertex_httpx import (
 from .llms.vertex_ai import VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.vertex_ai_partner import VertexAILlama3Config
-from .llms.sagemaker import SagemakerConfig
+from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
@@ -13,4 +13,6 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
     # this is an optional field in GenericStreamingChunk, it's not required to be present
     _all_fields.pop("provider_specific_fields", None)
 
-    return all(key in chunk for key in _all_fields)
+    decision = all(key in _all_fields for key in chunk)
+
+    return decision
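For reference, a minimal standalone sketch (not part of the commit) of what the reworked check now does: every key present in the candidate chunk must be a declared GenericStreamingChunk field, with the optional provider_specific_fields dropped first. Building _all_fields from the TypedDict's annotations is an assumption about the surrounding helper, shown here only to make the semantics concrete.

from litellm.types.utils import GenericStreamingChunk

# assumed to mirror the helper's setup: all declared fields, minus the optional one
_all_fields = dict(GenericStreamingChunk.__annotations__)
_all_fields.pop("provider_specific_fields", None)

chunk = {"text": "hi", "is_finished": False, "finish_reason": ""}
print(all(key in _all_fields for key in chunk))              # True - only known fields used

chunk_with_extra = {**chunk, "unexpected_key": 1}
print(all(key in _all_fields for key in chunk_with_extra))   # False - unknown key present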
@@ -7,7 +7,7 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -22,7 +22,11 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, ProviderField
+from litellm.types.utils import (
+    CustomStreamingDecoder,
+    GenericStreamingChunk,
+    ProviderField,
+)
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
 
 from .base import BaseLLM
@@ -171,15 +175,21 @@ async def make_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     response = await client.post(api_base, headers=headers, data=data, stream=True)
 
     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.text)
 
-    completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
-    )
+    if streaming_decoder is not None:
+        completion_stream: Any = streaming_decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.aiter_lines(), sync_stream=False
+        )
     # LOGGING
     logging_obj.post_call(
         input=messages,
@@ -199,6 +209,7 @@ def make_sync_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     if client is None:
         client = HTTPHandler()  # Create a new client if none provided
@@ -208,9 +219,14 @@ def make_sync_call(
     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.read())
 
-    completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
-    )
+    if streaming_decoder is not None:
+        completion_stream = streaming_decoder.iter_bytes(
+            response.iter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.iter_lines(), sync_stream=True
+        )
 
     # LOGGING
     logging_obj.post_call(
@@ -283,6 +299,7 @@ class DatabricksChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
     ) -> CustomStreamWrapper:
 
         data["stream"] = True
@@ -296,6 +313,7 @@ class DatabricksChatCompletion(BaseLLM):
                 model=model,
                 messages=messages,
                 logging_obj=logging_obj,
+                streaming_decoder=streaming_decoder,
             ),
             model=model,
             custom_llm_provider=custom_llm_provider,
@@ -371,6 +389,9 @@ class DatabricksChatCompletion(BaseLLM):
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
         custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[
+            CustomStreamingDecoder
+        ] = None,  # if openai-compatible api needs custom stream decoder - e.g. sagemaker
     ):
         custom_endpoint = custom_endpoint or optional_params.pop(
             "custom_endpoint", None
@@ -436,6 +457,7 @@ class DatabricksChatCompletion(BaseLLM):
                 headers=headers,
                 client=client,
                 custom_llm_provider=custom_llm_provider,
+                streaming_decoder=streaming_decoder,
             )
         else:
             return self.acompletion_function(
@@ -473,6 +495,7 @@ class DatabricksChatCompletion(BaseLLM):
                 model=model,
                 messages=messages,
                 logging_obj=logging_obj,
+                streaming_decoder=streaming_decoder,
            ),
            model=model,
            custom_llm_provider=custom_llm_provider,
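The handler changes above add an optional streaming_decoder hook: when one is supplied, make_call and make_sync_call feed it the raw byte stream (response.aiter_bytes / response.iter_bytes with chunk_size=1024) instead of wrapping line-delimited SSE in the default ModelResponseIterator. Below is a minimal sketch of what such a decoder could look like, assuming a hypothetical newline-delimited-JSON wire format; the class name and framing are illustrative only, and the GenericStreamingChunk fields follow the TypedDict shown later in this commit.

import json
from typing import AsyncIterator, Iterator, Optional, Union

from litellm.types.utils import (
    CustomStreamingDecoder,
    GenericStreamingChunk,
    StreamingChatCompletionChunk,
)

Chunk = Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]


class NewlineJSONDecoder(CustomStreamingDecoder):
    """Toy decoder for a hypothetical newline-delimited-JSON wire format."""

    def _to_chunk(self, line: bytes) -> Chunk:
        data = json.loads(line)
        return GenericStreamingChunk(
            text=data.get("text", ""),
            tool_use=None,
            is_finished=data.get("done", False),
            finish_reason="stop" if data.get("done") else "",
            usage=None,
            index=0,
        )

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[Chunk]:
        buffer = b""
        for raw in iterator:  # raw chunks from response.iter_bytes(chunk_size=1024)
            buffer += raw
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if line.strip():
                    yield self._to_chunk(line)

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Chunk]:
        buffer = b""
        async for raw in iterator:  # raw chunks from response.aiter_bytes(chunk_size=1024)
            buffer += raw
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if line.strip():
                    yield self._to_chunk(line)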
@@ -24,8 +24,11 @@ from litellm.llms.custom_httpx.http_handler import (
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
 )
+from litellm.types.utils import CustomStreamingDecoder
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -34,8 +37,8 @@ from litellm.utils import (
     get_secret,
 )
 
-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory
 
 _response_stream_shape_cache = None
 
@@ -241,6 +244,10 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )
 
+            custom_stream_decoder = AWSEventStreamDecoder(
+                model="", is_messages_api=True
+            )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -259,6 +266,7 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
+                streaming_decoder=custom_stream_decoder,  # type: ignore
            )
 
        ## Load Config
@@ -332,7 +340,7 @@ class SagemakerLLM(BaseAWSLLM):
             )
             return response
         else:
-            if stream is not None and stream == True:
+            if stream is not None and stream is True:
                 sync_handler = _get_httpx_client()
                 sync_response = sync_handler.post(
                     url=prepared_request.url,
@@ -847,12 +855,21 @@ def get_response_stream_shape():
 
 
 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser
 
         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+
+        return openai_chunk
 
     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -868,6 +885,7 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
+                usage=None,
             )
 
         return GChunk(
@@ -875,9 +893,12 @@
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
+            usage=None,
         )
 
-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -898,7 +919,10 @@
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
                     except json.JSONDecodeError:
@@ -909,16 +933,20 @@
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError:
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError as e:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GChunk]:
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -940,7 +968,10 @@
                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
                    except json.JSONDecodeError:
@@ -951,12 +982,16 @@
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
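The iter_bytes / aiter_bytes changes above keep the existing accumulate-then-parse loop and only change which parser each decoded payload is handed to (plus a trailing None when leftover bytes never parse). A minimal standalone sketch (not part of the commit) of that buffering pattern, with the botocore event-stream plumbing stripped out:

import json


def parse_fragments(fragments):
    accumulated_json = ""
    for fragment in fragments:
        accumulated_json += fragment
        try:
            data = json.loads(accumulated_json)
            yield data
            accumulated_json = ""  # reset only after a successful parse
        except json.JSONDecodeError:
            continue  # keep buffering until the JSON is complete
    if accumulated_json:
        yield None  # trailing bytes never formed valid JSON


print(list(parse_fragments(['{"text": "he', 'llo"}', '{"text": "world"}'])))
# -> [{'text': 'hello'}, {'text': 'world'}]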
@@ -119,7 +119,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
-from .llms.sagemaker import SagemakerLLM
+from .llms.sagemaker.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
 from .llms.triton import TritonChatCompletion
@@ -120,15 +120,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):
 
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [False, True])
-async def test_completion_sagemaker_stream(sync_mode):
+@pytest.mark.parametrize(
+    "model",
+    [
+        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+    ],
+)
+async def test_completion_sagemaker_stream(sync_mode, model):
     try:
+        from litellm.tests.test_streaming import streaming_format_tests
+
         litellm.set_verbose = False
         print("testing sagemaker")
         verbose_logger.setLevel(logging.DEBUG)
         full_text = ""
         if sync_mode is True:
             response = litellm.completion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -138,14 +147,15 @@ async def test_completion_sagemaker_stream(sync_mode):
                 input_cost_per_second=0.000420,
             )
 
-            for chunk in response:
+            for idx, chunk in enumerate(response):
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
 
             print("SYNC RESPONSE full text", full_text)
         else:
             response = await litellm.acompletion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -156,10 +166,12 @@
             )
 
             print("streaming response")
-
+            idx = 0
             async for chunk in response:
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
+                idx += 1
 
             print("ASYNC RESPONSE full text", full_text)
 
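The test now exercises streaming against both a raw SageMaker endpoint and a messages-api (sagemaker_chat) endpoint, and chunks from either arrive as OpenAI-style streaming chunks, so the same iteration code works for both. For reference, a minimal usage sketch of the behaviour under test; the endpoint name is a placeholder:

import litellm

response = litellm.completion(
    model="sagemaker_chat/<your-endpoint-name>",  # placeholder endpoint
    messages=[{"role": "user", "content": "hi - what is ur name"}],
    stream=True,
)
for chunk in response:
    # each chunk is an OpenAI-style streaming chunk; delta.content may be None
    print(chunk.choices[0].delta.content or "", end="")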
@@ -29,6 +29,7 @@ from openai.types.beta.thread_create_params import (
 from openai.types.beta.threads.message import Message as OpenAIMessage
 from openai.types.beta.threads.message_content import MessageContent
 from openai.types.beta.threads.run import Run
+from openai.types.chat import ChatCompletionChunk
 from pydantic import BaseModel, Field
 from typing_extensions import Dict, Required, TypedDict, override
 
@@ -456,6 +457,13 @@ class ChatCompletionUsageBlock(TypedDict):
     total_tokens: int
 
 
+class OpenAIChatCompletionChunk(ChatCompletionChunk):
+    def __init__(self, **kwargs):
+        # Set the 'object' kwarg to 'chat.completion.chunk'
+        kwargs["object"] = "chat.completion.chunk"
+        super().__init__(**kwargs)
+
+
 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
     learning_rate_multiplier: Optional[Union[str, float]] = (
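A minimal sketch (not part of the commit) of what the new OpenAIChatCompletionChunk subclass buys: a payload that omits the object field, or sets it to something else, still validates as a standard chat.completion.chunk because __init__ overwrites it before pydantic validation runs. All values below are hypothetical:

from litellm.types.llms.openai import OpenAIChatCompletionChunk

chunk = OpenAIChatCompletionChunk(
    id="chatcmpl-123",  # hypothetical values; note: no "object" key supplied
    created=1724400000,
    model="my-endpoint",
    choices=[{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
)
print(chunk.object)  # "chat.completion.chunk", injected by __init__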
@@ -5,11 +5,16 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 from openai._models import BaseModel as OpenAIObject
+from openai.types.completion_usage import CompletionUsage
 from pydantic import ConfigDict, Field, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override
 
 from ..litellm_core_utils.core_helpers import map_finish_reason
-from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from .llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
+)
 
 
 def _generate_id():  # private helper function
@@ -85,7 +90,7 @@ class GenericStreamingChunk(TypedDict, total=False):
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
-    usage: Optional[ChatCompletionUsageBlock]
+    usage: Required[Optional[ChatCompletionUsageBlock]]
     index: int
 
     # use this dict if you want to return any provider specific fields in the response
@@ -448,9 +453,6 @@ class Choices(OpenAIObject):
             setattr(self, key, value)
 
 
-from openai.types.completion_usage import CompletionUsage
-
-
 class Usage(CompletionUsage):
     def __init__(
         self,
@@ -535,6 +537,17 @@ class StreamingChoices(OpenAIObject):
             setattr(self, key, value)
 
 
+class StreamingChatCompletionChunk(OpenAIChatCompletionChunk):
+    def __init__(self, **kwargs):
+
+        new_choices = []
+        for choice in kwargs["choices"]:
+            new_choice = StreamingChoices(**choice).model_dump()
+            new_choices.append(new_choice)
+        kwargs["choices"] = new_choices
+        super().__init__(**kwargs)
+
+
 class ModelResponse(OpenAIObject):
     id: str
     """A unique identifier for the completion."""
@@ -1231,3 +1244,20 @@ class StandardLoggingPayload(TypedDict):
     response: Optional[Union[str, list, dict]]
     model_parameters: dict
     hidden_params: StandardLoggingHiddenParams
+
+
+from typing import AsyncIterator, Iterator
+
+
+class CustomStreamingDecoder:
+    async def aiter_bytes(
+        self, iterator: AsyncIterator[bytes]
+    ) -> AsyncIterator[
+        Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]
+    ]:
+        raise NotImplementedError
+
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
+        raise NotImplementedError
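A minimal sketch (not part of the commit) of what StreamingChatCompletionChunk does with its input: each raw choice dict is passed through litellm's StreamingChoices and re-dumped before the OpenAI pydantic validation runs, which is what lets the SageMaker messages-api payloads decoded above validate cleanly. The values are hypothetical, and the example assumes StreamingChoices accepts a plain delta dict, which the conversion in this commit relies on:

from litellm.types.utils import StreamingChatCompletionChunk

chunk = StreamingChatCompletionChunk(
    id="chatcmpl-123",  # hypothetical values
    created=1724400000,
    model="my-endpoint",
    choices=[
        {"index": 0, "delta": {"role": "assistant", "content": "Hi"}, "finish_reason": None}
    ],
)
print(chunk.object)                    # "chat.completion.chunk", set by the parent class
print(chunk.choices[0].delta.content)  # "Hi"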