fix(sagemaker.py): support streaming for messages api

Fixes https://github.com/BerriAI/litellm/issues/5372
Krrish Dholakia 2024-08-26 15:08:08 -07:00
parent 174b1c43e3
commit 8e9acd117b
8 changed files with 142 additions and 32 deletions
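
At a high level, this change lets an OpenAI-compatible provider hand the shared Databricks streaming path a custom byte-stream decoder (here, for SageMaker's Messages API) instead of the default line-based ModelResponseIterator. The intended end-user effect is streaming like the sketch below; the "sagemaker_chat/" prefix and the endpoint name are assumptions for illustration, not taken from this diff.

import litellm

# Hedged usage sketch: stream from a SageMaker Messages-API endpoint.
# The "sagemaker_chat/" prefix and endpoint name are hypothetical here.
response = litellm.completion(
    model="sagemaker_chat/my-endpoint-name",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")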


@@ -7,7 +7,7 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -22,7 +22,11 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, ProviderField
+from litellm.types.utils import (
+    CustomStreamingDecoder,
+    GenericStreamingChunk,
+    ProviderField,
+)
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@@ -171,15 +175,21 @@ async def make_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     response = await client.post(api_base, headers=headers, data=data, stream=True)

     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.text)
-    completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
-    )
+    if streaming_decoder is not None:
+        completion_stream: Any = streaming_decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.aiter_lines(), sync_stream=False
+        )

     # LOGGING
     logging_obj.post_call(
         input=messages,
@@ -199,6 +209,7 @@ def make_sync_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     if client is None:
         client = HTTPHandler()  # Create a new client if none provided
@@ -208,9 +219,14 @@
     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.read())

-    completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
-    )
+    if streaming_decoder is not None:
+        completion_stream = streaming_decoder.iter_bytes(
+            response.iter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.iter_lines(), sync_stream=True
+        )

     # LOGGING
     logging_obj.post_call(
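
Both call sites above treat the decoder the same way: if one is supplied, raw response bytes go to it instead of the default line iterator. Below is a minimal sketch of the protocol those call sites imply; the real CustomStreamingDecoder lives in litellm/types/utils.py and may differ, and NewlineJSONDecoder is a hypothetical example, not part of litellm.

import json
from typing import Any, AsyncIterator, Iterator


class NewlineJSONDecoder:
    # Hypothetical decoder: parses newline-delimited JSON out of raw bytes.
    # A real decoder would yield GenericStreamingChunk-shaped objects.

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[Any]:
        # Sync path, matching make_sync_call: buffer bytes, yield per line.
        buffer = b""
        for chunk in iterator:
            buffer += chunk
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if line.strip():
                    yield json.loads(line)

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Any]:
        # Async path, matching make_call: same framing over an async iterator.
        buffer = b""
        async for chunk in iterator:
            buffer += chunk
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if line.strip():
                    yield json.loads(line)

Note that neither call site awaits aiter_bytes, so declaring it as an async generator (as here) or as a plain method returning an async iterator both fit.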
@@ -283,6 +299,7 @@ class DatabricksChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
     ) -> CustomStreamWrapper:

         data["stream"] = True
@@ -296,6 +313,7 @@ class DatabricksChatCompletion(BaseLLM):
                 model=model,
                 messages=messages,
                 logging_obj=logging_obj,
+                streaming_decoder=streaming_decoder,
             ),
             model=model,
             custom_llm_provider=custom_llm_provider,
@@ -371,6 +389,9 @@ class DatabricksChatCompletion(BaseLLM):
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
         custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[
+            CustomStreamingDecoder
+        ] = None,  # if openai-compatible api needs custom stream decoder - e.g. sagemaker
     ):
         custom_endpoint = custom_endpoint or optional_params.pop(
             "custom_endpoint", None
@@ -436,6 +457,7 @@ class DatabricksChatCompletion(BaseLLM):
                     headers=headers,
                     client=client,
                     custom_llm_provider=custom_llm_provider,
+                    streaming_decoder=streaming_decoder,
                 )
             else:
                 return self.acompletion_function(
@@ -473,6 +495,7 @@ class DatabricksChatCompletion(BaseLLM):
                         model=model,
                         messages=messages,
                         logging_obj=logging_obj,
+                        streaming_decoder=streaming_decoder,
                     ),
                     model=model,
                     custom_llm_provider=custom_llm_provider,
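
To see that the sync and async branches behave identically on the same framing, here is a runnable toy that drives the hypothetical NewlineJSONDecoder from the sketch above with fake byte streams (no network, no litellm internals).

import asyncio


def fake_sync_bytes():
    # Simulates response.iter_bytes(chunk_size=1024): arbitrary byte splits.
    yield b'{"text": "hel'
    yield b'lo"}\n{"text": "world"}\n'


async def fake_async_bytes():
    # Simulates response.aiter_bytes(chunk_size=1024).
    for piece in fake_sync_bytes():
        yield piece


decoder = NewlineJSONDecoder()  # hypothetical class from the sketch above

# Sync path, mirroring make_sync_call:
for chunk in decoder.iter_bytes(fake_sync_bytes()):
    print(chunk)  # {'text': 'hello'} then {'text': 'world'}


async def main():
    # Async path, mirroring make_call:
    async for chunk in decoder.aiter_bytes(fake_async_bytes()):
        print(chunk)


asyncio.run(main())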