mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

fix(sagemaker.py): support streaming for messages api

Fixes https://github.com/BerriAI/litellm/issues/5372

commit 8e9acd117b (parent 174b1c43e3)
8 changed files with 142 additions and 32 deletions
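This commit wires a custom stream decoder into the OpenAI-compatible (messages API) path so `sagemaker_chat/...` endpoints can stream. A minimal usage sketch, assuming a deployed SageMaker messages-API endpoint (the endpoint name below is a placeholder), mirroring the parameterized test added in this commit:

import litellm

# Placeholder endpoint name - substitute a real SageMaker messages-API deployment.
response = litellm.completion(
    model="sagemaker_chat/<your-endpoint-name>",
    messages=[{"role": "user", "content": "hi - what is ur name"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")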
@@ -856,7 +856,7 @@ from .llms.vertex_httpx import (
 from .llms.vertex_ai import VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.vertex_ai_partner import VertexAILlama3Config
-from .llms.sagemaker import SagemakerConfig
+from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
@@ -13,4 +13,6 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
     # this is an optional field in GenericStreamingChunk, it's not required to be present
     _all_fields.pop("provider_specific_fields", None)

-    return all(key in chunk for key in _all_fields)
+    decision = all(key in _all_fields for key in chunk)
+
+    return decision
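The return value is now computed by checking that every key present in the chunk is a recognized GenericStreamingChunk field, rather than checking that every required field is present in the chunk. A small illustration with plain dicts; the field set below is abbreviated and purely illustrative:

_all_fields = {"text": None, "is_finished": None, "finish_reason": None, "usage": None, "index": None}
chunk = {"text": "hi", "is_finished": False, "finish_reason": ""}

decision = all(key in _all_fields for key in chunk)
print(decision)  # True - every key the chunk carries is a known field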
@@ -7,7 +7,7 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -22,7 +22,11 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, ProviderField
+from litellm.types.utils import (
+    CustomStreamingDecoder,
+    GenericStreamingChunk,
+    ProviderField,
+)
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@@ -171,15 +175,21 @@ async def make_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     response = await client.post(api_base, headers=headers, data=data, stream=True)

     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.text)

-    completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
-    )
+    if streaming_decoder is not None:
+        completion_stream: Any = streaming_decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.aiter_lines(), sync_stream=False
+        )
     # LOGGING
     logging_obj.post_call(
         input=messages,
@@ -199,6 +209,7 @@ def make_sync_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     if client is None:
         client = HTTPHandler()  # Create a new client if none provided
@@ -208,9 +219,14 @@ def make_sync_call(
     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.read())

-    completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
-    )
+    if streaming_decoder is not None:
+        completion_stream = streaming_decoder.iter_bytes(
+            response.iter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.iter_lines(), sync_stream=True
+        )

     # LOGGING
     logging_obj.post_call(
@@ -283,6 +299,7 @@ class DatabricksChatCompletion(BaseLLM):
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
+       streaming_decoder: Optional[CustomStreamingDecoder] = None,
    ) -> CustomStreamWrapper:

        data["stream"] = True
@@ -296,6 +313,7 @@ class DatabricksChatCompletion(BaseLLM):
                model=model,
                messages=messages,
                logging_obj=logging_obj,
+               streaming_decoder=streaming_decoder,
            ),
            model=model,
            custom_llm_provider=custom_llm_provider,
@@ -371,6 +389,9 @@ class DatabricksChatCompletion(BaseLLM):
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
+       streaming_decoder: Optional[
+           CustomStreamingDecoder
+       ] = None,  # if openai-compatible api needs custom stream decoder - e.g. sagemaker
    ):
        custom_endpoint = custom_endpoint or optional_params.pop(
            "custom_endpoint", None
@@ -436,6 +457,7 @@ class DatabricksChatCompletion(BaseLLM):
                    headers=headers,
                    client=client,
                    custom_llm_provider=custom_llm_provider,
+                   streaming_decoder=streaming_decoder,
                )
            else:
                return self.acompletion_function(
|
@ -473,6 +495,7 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
streaming_decoder=streaming_decoder,
|
||||||
),
|
),
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider=custom_llm_provider,
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
|
|
@@ -24,8 +24,11 @@ from litellm.llms.custom_httpx.http_handler import (
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
 )
+from litellm.types.utils import CustomStreamingDecoder
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
|
@ -34,8 +37,8 @@ from litellm.utils import (
|
||||||
get_secret,
|
get_secret,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .base_aws_llm import BaseAWSLLM
|
from ..base_aws_llm import BaseAWSLLM
|
||||||
from .prompt_templates.factory import custom_prompt, prompt_factory
|
from ..prompt_templates.factory import custom_prompt, prompt_factory
|
||||||
|
|
||||||
_response_stream_shape_cache = None
|
_response_stream_shape_cache = None
|
||||||
|
|
||||||
|
@@ -241,6 +244,10 @@ class SagemakerLLM(BaseAWSLLM):
                aws_region_name=aws_region_name,
            )

+           custom_stream_decoder = AWSEventStreamDecoder(
+               model="", is_messages_api=True
+           )
+
            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
@@ -259,6 +266,7 @@ class SagemakerLLM(BaseAWSLLM):
                headers=prepared_request.headers,
                custom_endpoint=True,
                custom_llm_provider="sagemaker_chat",
+               streaming_decoder=custom_stream_decoder,  # type: ignore
            )

        ## Load Config
@@ -332,7 +340,7 @@ class SagemakerLLM(BaseAWSLLM):
            )
            return response
        else:
-           if stream is not None and stream == True:
+           if stream is not None and stream is True:
                sync_handler = _get_httpx_client()
                sync_response = sync_handler.post(
                    url=prepared_request.url,
@@ -847,12 +855,21 @@ def get_response_stream_shape():


 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser

         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+
+        return openai_chunk

     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
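For the messages API, each decoded event is already an OpenAI-style chat-completion chunk, so `_chunk_parser_messages_api` only has to validate it into a `StreamingChatCompletionChunk`. A hedged example of such a payload (field values are illustrative, not captured from a real endpoint):

chunk_data = {
    "id": "chatcmpl-0",
    "object": "chat.completion.chunk",
    "created": 1724000000,
    "model": "sagemaker-endpoint",
    "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
}
# _chunk_parser_messages_api(chunk_data) would validate this into a StreamingChatCompletionChunk.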
@@ -868,6 +885,7 @@ class AWSEventStreamDecoder:
                index=_index,
                is_finished=True,
                finish_reason="stop",
+               usage=None,
            )

        return GChunk(
@@ -875,9 +893,12 @@ class AWSEventStreamDecoder:
            index=_index,
            is_finished=is_finished,
            finish_reason=finish_reason,
+           usage=None,
        )

-   def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+   def iter_bytes(
+       self, iterator: Iterator[bytes]
+   ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

@@ -898,7 +919,10 @@ class AWSEventStreamDecoder:
                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
-                       yield self._chunk_parser(chunk_data=_data)
+                       if self.is_messages_api:
+                           yield self._chunk_parser_messages_api(chunk_data=_data)
+                       else:
+                           yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
                    except json.JSONDecodeError:
@@ -909,16 +933,20 @@ class AWSEventStreamDecoder:
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
-               yield self._chunk_parser(chunk_data=_data)
-           except json.JSONDecodeError:
+               if self.is_messages_api:
+                   yield self._chunk_parser_messages_api(chunk_data=_data)
+               else:
+                   yield self._chunk_parser(chunk_data=_data)
+           except json.JSONDecodeError as e:
                # Handle or log any unparseable data at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
+               yield None

    async def aiter_bytes(
        self, iterator: AsyncIterator[bytes]
-   ) -> AsyncIterator[GChunk]:
+   ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

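The decoder now receives raw bytes in fixed-size pieces (`response.iter_bytes(chunk_size=1024)` / `response.aiter_bytes(chunk_size=1024)`), so a single JSON event can arrive split across pieces; that is what the accumulate-then-parse loop above handles. A standalone illustration of the pattern, with a contrived payload and split point:

import json

pieces = [b'{"token": {"text": "Hel', b'lo"}}']  # one event split across two byte pieces
accumulated = ""
for piece in pieces:
    accumulated += piece.decode("utf-8")
    try:
        event = json.loads(accumulated)
        print(event)       # {'token': {'text': 'Hello'}}
        accumulated = ""   # reset after a successful parse
    except json.JSONDecodeError:
        continue           # incomplete JSON - wait for more bytes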
@@ -940,7 +968,10 @@ class AWSEventStreamDecoder:
                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
-                       yield self._chunk_parser(chunk_data=_data)
+                       if self.is_messages_api:
+                           yield self._chunk_parser_messages_api(chunk_data=_data)
+                       else:
+                           yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
                    except json.JSONDecodeError:
@@ -951,12 +982,16 @@ class AWSEventStreamDecoder:
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
-               yield self._chunk_parser(chunk_data=_data)
+               if self.is_messages_api:
+                   yield self._chunk_parser_messages_api(chunk_data=_data)
+               else:
+                   yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError:
                # Handle or log any unparseable data at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
+               yield None

    def _parse_message_from_event(self, event) -> Optional[str]:
        response_dict = event.to_response_dict()
@@ -119,7 +119,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
-from .llms.sagemaker import SagemakerLLM
+from .llms.sagemaker.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
 from .llms.triton import TritonChatCompletion
@@ -120,15 +120,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):

 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [False, True])
-async def test_completion_sagemaker_stream(sync_mode):
+@pytest.mark.parametrize(
+    "model",
+    [
+        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+    ],
+)
+async def test_completion_sagemaker_stream(sync_mode, model):
     try:
+        from litellm.tests.test_streaming import streaming_format_tests
+
         litellm.set_verbose = False
         print("testing sagemaker")
         verbose_logger.setLevel(logging.DEBUG)
         full_text = ""
         if sync_mode is True:
             response = litellm.completion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -138,14 +147,15 @@ async def test_completion_sagemaker_stream(sync_mode):
                 input_cost_per_second=0.000420,
             )

-            for chunk in response:
+            for idx, chunk in enumerate(response):
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""

             print("SYNC RESPONSE full text", full_text)
         else:
             response = await litellm.acompletion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -156,10 +166,12 @@ async def test_completion_sagemaker_stream(sync_mode):
             )

             print("streaming response")
+            idx = 0
             async for chunk in response:
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
+                idx += 1

             print("ASYNC RESPONSE full text", full_text)

@@ -29,6 +29,7 @@ from openai.types.beta.thread_create_params import (
 from openai.types.beta.threads.message import Message as OpenAIMessage
 from openai.types.beta.threads.message_content import MessageContent
 from openai.types.beta.threads.run import Run
+from openai.types.chat import ChatCompletionChunk
 from pydantic import BaseModel, Field
 from typing_extensions import Dict, Required, TypedDict, override

@@ -456,6 +457,13 @@ class ChatCompletionUsageBlock(TypedDict):
     total_tokens: int


+class OpenAIChatCompletionChunk(ChatCompletionChunk):
+    def __init__(self, **kwargs):
+        # Set the 'object' kwarg to 'chat.completion.chunk'
+        kwargs["object"] = "chat.completion.chunk"
+        super().__init__(**kwargs)
+
+
 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
     learning_rate_multiplier: Optional[Union[str, float]] = (
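The subclass lets callers build a chunk without supplying the `object` discriminator themselves; it is forced to "chat.completion.chunk" before pydantic validation runs. A hedged construction example (field values are illustrative):

chunk = OpenAIChatCompletionChunk(
    id="chatcmpl-0",
    created=1724000000,
    model="sagemaker-endpoint",
    choices=[{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
)
assert chunk.object == "chat.completion.chunk"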
@@ -5,11 +5,16 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 from openai._models import BaseModel as OpenAIObject
+from openai.types.completion_usage import CompletionUsage
 from pydantic import ConfigDict, Field, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override

 from ..litellm_core_utils.core_helpers import map_finish_reason
-from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from .llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
+)


 def _generate_id():  # private helper function
@@ -85,7 +90,7 @@ class GenericStreamingChunk(TypedDict, total=False):
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
-    usage: Optional[ChatCompletionUsageBlock]
+    usage: Required[Optional[ChatCompletionUsageBlock]]
     index: int

     # use this dict if you want to return any provider specific fields in the response
@@ -448,9 +453,6 @@ class Choices(OpenAIObject):
             setattr(self, key, value)


-from openai.types.completion_usage import CompletionUsage
-
-
 class Usage(CompletionUsage):
     def __init__(
         self,
@@ -535,6 +537,17 @@ class StreamingChoices(OpenAIObject):
             setattr(self, key, value)


+class StreamingChatCompletionChunk(OpenAIChatCompletionChunk):
+    def __init__(self, **kwargs):
+
+        new_choices = []
+        for choice in kwargs["choices"]:
+            new_choice = StreamingChoices(**choice).model_dump()
+            new_choices.append(new_choice)
+        kwargs["choices"] = new_choices
+        super().__init__(**kwargs)
+
+
 class ModelResponse(OpenAIObject):
     id: str
     """A unique identifier for the completion."""
@@ -1231,3 +1244,20 @@ class StandardLoggingPayload(TypedDict):
     response: Optional[Union[str, list, dict]]
     model_parameters: dict
     hidden_params: StandardLoggingHiddenParams
+
+
+from typing import AsyncIterator, Iterator
+
+
+class CustomStreamingDecoder:
+    async def aiter_bytes(
+        self, iterator: AsyncIterator[bytes]
+    ) -> AsyncIterator[
+        Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]
+    ]:
+        raise NotImplementedError
+
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
+        raise NotImplementedError
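`CustomStreamingDecoder` is the extension point that `make_call` / `make_sync_call` in the OpenAI-compatible handler dispatch on; `AWSEventStreamDecoder(model="", is_messages_api=True)` fills this role for SageMaker. A minimal sketch of an alternative decoder, assuming newline-delimited JSON on the wire and yielding raw dicts rather than chunk objects (illustrative only, not part of litellm; the async side is omitted and would still raise NotImplementedError):

import json
from typing import Iterator

from litellm.types.utils import CustomStreamingDecoder


class NewlineJSONDecoder(CustomStreamingDecoder):
    """Illustrative decoder: one JSON document per line of the byte stream."""

    def iter_bytes(self, iterator: Iterator[bytes]):
        buffer = b""
        for piece in iterator:
            buffer += piece
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if line.strip():
                    yield json.loads(line)  # a real decoder would build chunk objects here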