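"""SageMaker (AWS) event-stream decoding helpers for litellm.

Translates `sagemaker-runtime` InvokeEndpointWithResponseStream events into
litellm streaming chunk types (GenericStreamingChunk /
StreamingChatCompletionChunk).
"""
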
import json
from typing import AsyncIterator, Iterator, List, Optional, Union

import httpx

from litellm import verbose_logger
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk

# Cached botocore shape for the SageMaker response stream; populated lazily by
# get_response_stream_shape() below.
_response_stream_shape_cache = None


class SagemakerError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)


class AWSEventStreamDecoder:
    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
        from botocore.parsers import EventStreamJSONParser

        self.model = model
        self.parser = EventStreamJSONParser()
        self.content_blocks: List = []
        self.is_messages_api = is_messages_api

    def _chunk_parser_messages_api(
        self, chunk_data: dict
    ) -> StreamingChatCompletionChunk:
        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
        return openai_chunk

    def _chunk_parser(self, chunk_data: dict) -> GChunk:
        verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
        _token = chunk_data.get("token", {}) or {}
        _index = chunk_data.get("index", None) or 0
        is_finished = False
        finish_reason = ""

        _text = _token.get("text", "")
        if _text == "<|endoftext|>":
            # TGI-style end-of-stream sentinel: emit an empty, final chunk.
            return GChunk(
                text="",
                index=_index,
                is_finished=True,
                finish_reason="stop",
                usage=None,
            )

        return GChunk(
            text=_text,
            index=_index,
            is_finished=is_finished,
            finish_reason=finish_reason,
            usage=None,
        )

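    # Illustrative `chunk_data` payload for _chunk_parser above, inferred from
    # the keys it reads (real SageMaker/TGI responses may carry more fields):
    #
    #   {"index": 0, "token": {"text": "Hello"}}
    #
    # A token text of "<|endoftext|>" marks the final chunk (finish_reason="stop").
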
    def iter_bytes(
        self, iterator: Iterator[bytes]
    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
        """Given an iterator that yields raw bytes, iterate over it & yield every event encountered."""
        from botocore.eventstream import EventStreamBuffer

        event_stream_buffer = EventStreamBuffer()
        accumulated_json = ""

        for chunk in iterator:
            event_stream_buffer.add_data(chunk)
            for event in event_stream_buffer:
                message = self._parse_message_from_event(event)
                if message:
                    # remove the SSE "data:" prefix and trailing "\n\n"
                    message = message.replace("data:", "").replace("\n\n", "")

                    # accumulate JSON data across events until it parses
                    accumulated_json += message

                    # try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                            yield self._chunk_parser(chunk_data=_data)
                        # reset accumulated_json after successful parsing
                        accumulated_json = ""
                    except json.JSONDecodeError:
                        # not valid JSON yet - wait for the next event
                        continue

        # handle any remaining data after the iterator is exhausted
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                    yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError:
                # log any unparseable data left at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None

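    # Usage sketch for iter_bytes (hypothetical caller; `sync_response` is
    # assumed to be an httpx.Response opened with stream=True against a
    # SageMaker streaming endpoint):
    #
    #   decoder = AWSEventStreamDecoder(model="my-endpoint")
    #   for parsed in decoder.iter_bytes(sync_response.iter_bytes(chunk_size=1024)):
    #       if parsed is not None:
    #           ...  # GChunk, or StreamingChatCompletionChunk for the messages API
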
    async def aiter_bytes(
        self, iterator: AsyncIterator[bytes]
    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
        """Given an async iterator that yields raw bytes, iterate over it & yield every event encountered."""
        from botocore.eventstream import EventStreamBuffer

        event_stream_buffer = EventStreamBuffer()
        accumulated_json = ""

        async for chunk in iterator:
            event_stream_buffer.add_data(chunk)
            for event in event_stream_buffer:
                message = self._parse_message_from_event(event)
                if message:
                    verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
                    # remove the SSE "data:" prefix and trailing "\n\n"
                    message = message.replace("data:", "").replace("\n\n", "")

                    # accumulate JSON data across events until it parses
                    accumulated_json += message

                    # try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                            yield self._chunk_parser(chunk_data=_data)
                        # reset accumulated_json after successful parsing
                        accumulated_json = ""
                    except json.JSONDecodeError:
                        # not valid JSON yet - wait for the next event
                        continue

        # handle any remaining data after the iterator is exhausted
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                    yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError:
                # log any unparseable data left at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None

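    # Async counterpart sketch (hypothetical; `async_response` is assumed to be
    # an httpx.Response obtained from httpx.AsyncClient with stream=True):
    #
    #   async for parsed in decoder.aiter_bytes(async_response.aiter_bytes(chunk_size=1024)):
    #       if parsed is not None:
    #           ...
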
    def _parse_message_from_event(self, event) -> Optional[str]:
        response_dict = event.to_response_dict()
        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())

        if response_dict["status_code"] != 200:
            raise ValueError(f"Bad response code, expected 200: {response_dict}")

        if "chunk" in parsed_response:
            # event-stream payload: the message bytes live under "chunk"
            chunk = parsed_response.get("chunk")
            if not chunk:
                return None
            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
        else:
            # fall back to the raw response body
            chunk = response_dict.get("body")
            if not chunk:
                return None

            return chunk.decode()  # type: ignore[no-any-return]


def get_response_stream_shape():
    global _response_stream_shape_cache
    if _response_stream_shape_cache is None:
        from botocore.loaders import Loader
        from botocore.model import ServiceModel

        # Loading the botocore service model is comparatively expensive, so the
        # resolved shape is cached at module level and reused for every event.
        loader = Loader()
        sagemaker_service_dict = loader.load_service_model(
            "sagemaker-runtime", "service-2"
        )
        sagemaker_service_model = ServiceModel(sagemaker_service_dict)
        _response_stream_shape_cache = sagemaker_service_model.shape_for(
            "InvokeEndpointWithResponseStreamOutput"
        )
    return _response_stream_shape_cache
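
# Example (illustrative): repeated calls return the same cached shape object,
# so only the first event pays the service-model loading cost.
#
#   assert get_response_stream_shape() is get_response_stream_shape()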