Mirror of https://github.com/BerriAI/litellm.git
added streaming support for aiohttp_openai

Commit 7ac3a9cb83 (parent 887c66c6b7): 4 changed files with 355 additions and 23 deletions
@@ -7,13 +7,14 @@ https://github.com/BerriAI/litellm/issues/6592

New config to ensure we introduce this without causing breaking changes for users
"""

from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, AsyncIterator, Union, Iterator, List, Optional

from aiohttp import ClientResponse

from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, ModelResponse
from ..common_utils import ModelResponseIterator as AiohttpOpenAIResponseIterator

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

@@ -77,3 +78,15 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):

        model_response.object = _json_response.get("object")
        model_response.system_fingerprint = _json_response.get("system_fingerprint")
        return model_response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return AiohttpOpenAIResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
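Editor's note: a minimal sketch (not part of this commit) of what the new get_model_response_iterator hook does on its own, driven by a hand-rolled async byte stream instead of a real aiohttp response body. The import path and the no-argument constructor are assumptions based on the file layout this diff implies; in the handler changes further below, the real input is response.content from aiohttp.

import asyncio

from litellm.llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig


async def fake_sse_stream():
    # Two SSE lines shaped like an OpenAI-compatible streaming response.
    yield b'data: {"choices": [{"index": 0, "delta": {"content": "Hello"}}]}\n'
    yield b"data: [DONE]\n"


async def main():
    config = AiohttpOpenAIChatConfig()  # assumed no-arg constructor, like other litellm provider configs
    iterator = config.get_model_response_iterator(
        streaming_response=fake_sse_stream(), sync_stream=False
    )
    async for parsed in iterator:
        # Each item is a GenericStreamingChunk; the "[DONE]" line maps to
        # is_finished=True / finish_reason="stop".
        print(parsed)


asyncio.run(main())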

litellm/llms/aiohttp_openai/common_utils.py (new file, 169 lines added)
@@ -0,0 +1,169 @@

import json
from typing import List, Optional, Union

from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    GenericStreamingChunk,
    ModelResponseStream,
)


class AioHttpOpenAIError(BaseLLMException):
    def __init__(self, status_code, message):
        super().__init__(status_code=status_code, message=message)


def validate_environment(
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    api_key: Optional[str] = None,
) -> dict:
    """
    Return headers to use for aiohttp_openai chat completion request
    """
    headers.update(
        {
            "Request-Source": "unspecified:litellm",
            "accept": "application/json",
            "content-type": "application/json",
        }
    )
    if api_key:
        headers["Authorization"] = f"bearer {api_key}"
    return headers

class ModelResponseIterator:
    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> Union[GenericStreamingChunk, ModelResponseStream]:
        try:
            # Initialize default values
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            # Extract the index from the chunk
            index = int(chunk.get("choices", [{}])[0].get("index", 0))

            # Extract the text or delta content from the first choice
            delta = chunk.get("choices", [{}])[0].get("delta", {})
            if "content" in delta:
                text = delta["content"]

            # Check for finish_reason
            finish_reason = chunk.get("choices", [{}])[0].get("finish_reason", "")

            # Determine if the stream has finished
            is_finished = finish_reason in ("length", "stop")

            # Create and return the parsed chunk
            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")


    # Sync iterator
    def __iter__(self):
        return self

    def _handle_string_chunk(
        self, str_line: str
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        # chunk is a str at this point
        if "[DONE]" in str_line:
            return GenericStreamingChunk(
                text="",
                is_finished=True,
                finish_reason="stop",
                usage=None,
                index=0,
                tool_use=None,
            )
        elif str_line.startswith("data:"):
            data_json = json.loads(str_line[5:])
            return self.chunk_parser(chunk=data_json)
        else:
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )

    def __next__(self):
        try:
            chunk = self.response_iterator.__next__()
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            str_line = chunk
            if isinstance(chunk, bytes):  # Handle binary data
                str_line = chunk.decode("utf-8")  # Convert bytes to string
                index = str_line.find("data:")
                if index != -1:
                    str_line = str_line[index:]
            # chunk is a str at this point
            return self._handle_string_chunk(str_line=str_line)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            str_line = chunk
            if isinstance(chunk, bytes):  # Handle binary data
                str_line = chunk.decode("utf-8")  # Convert bytes to string
                index = str_line.find("data:")
                if index != -1:
                    str_line = str_line[index:]

            # chunk is a str at this point
            return self._handle_string_chunk(str_line=str_line)
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
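Editor's note: a minimal sketch (not from the commit) of the sync parsing path in isolation, feeding ModelResponseIterator a list of SSE-style strings instead of a live response stream. The example payloads are illustrative, and it assumes a litellm checkout that contains this new module.

from litellm.llms.aiohttp_openai.common_utils import ModelResponseIterator

# SSE-style lines as they would arrive from an OpenAI-compatible endpoint.
lines = [
    'data: {"choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": null}]}',
    'data: {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}',
    "data: [DONE]",
]

iterator = ModelResponseIterator(
    streaming_response=iter(lines), sync_stream=True, json_mode=False
)
for chunk in iterator:
    # Each chunk is a GenericStreamingChunk with text / is_finished / finish_reason / index.
    print(chunk["text"], chunk["is_finished"], chunk["finish_reason"])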
@@ -174,6 +174,7 @@ class BaseLLMAIOHTTPHandler:

        api_key: Optional[str] = None,
        client: Optional[ClientSession] = None,
    ):
        data.pop("max_retries", None)  # added: "max_retries" is an extra param the OpenAI request body does not need
        _response = await self._make_common_async_call(
            async_client_session=client,
            provider_config=provider_config,
@@ -257,27 +258,50 @@

        )

        if acompletion is True:
            return self.async_completion(
                custom_llm_provider=custom_llm_provider,
                provider_config=provider_config,
                api_base=api_base,
                headers=headers,
                data=data,
                timeout=timeout,
                model=model,
                model_response=model_response,
                logging_obj=logging_obj,
                api_key=api_key,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                client=(
                    client
                    if client is not None and isinstance(client, ClientSession)
                    else None
                ),
            )
            if stream is True:
                if fake_stream is not True:
                    data["stream"] = stream
                return self.acompletion_stream_function(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    headers=headers,
                    custom_llm_provider=custom_llm_provider,
                    provider_config=provider_config,
                    timeout=timeout,
                    logging_obj=logging_obj,
                    data=data,
                    fake_stream=fake_stream,
                    client=(
                        client
                        if client is not None and isinstance(client, ClientSession)
                        else None
                    ),
                    litellm_params=litellm_params,
                )

            else:
                return self.async_completion(
                    custom_llm_provider=custom_llm_provider,
                    provider_config=provider_config,
                    api_base=api_base,
                    headers=headers,
                    data=data,
                    timeout=timeout,
                    model=model,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    api_key=api_key,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    client=(
                        client
                        if client is not None and isinstance(client, ClientSession)
                        else None
                    ),
                )

        if stream is True:
            if fake_stream is not True:
@@ -332,7 +356,95 @@

            litellm_params=litellm_params,
            encoding=encoding,
        )

    async def acompletion_stream_function(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        headers: dict,
        provider_config: BaseConfig,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
        data: dict,
        litellm_params: dict,
        fake_stream: bool = False,
        client: Optional[ClientSession] = None,
    ):
        completion_stream, _response_headers = await self.make_async_call(
            custom_llm_provider=custom_llm_provider,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            messages=messages,
            logging_obj=logging_obj,
            timeout=timeout,
            fake_stream=fake_stream,
            client=client,
            litellm_params=litellm_params,
        )

        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
        )
        return streamwrapper

    async def make_async_call(
        self,
        custom_llm_provider: str,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        logging_obj: LiteLLMLoggingObj,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        fake_stream: bool = False,
        client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None,
    ) -> Tuple[Any, httpx.Headers]:
        if client is None or not isinstance(client, ClientSession):
            async_client_session = self._get_async_client_session()

        stream = True
        if fake_stream is True:
            stream = False
        data.pop("max_retries", None)
        response = await self._make_common_async_call(
            async_client_session=async_client_session,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            timeout=timeout,
            litellm_params=litellm_params,
            stream=stream,
        )

        if fake_stream is True:
            json_response = await response.json()
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=json_response, sync_stream=False
            )
        else:
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=response.content, sync_stream=False
            )
        # LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response="first stream response received:: ",
            additional_args={"complete_input_dict": data},
        )

        return completion_stream, response.headers

    def make_sync_call(
        self,
        provider_config: BaseConfig,
@@ -372,7 +484,7 @@

            )
        else:
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=response.iter_lines(), sync_stream=True
                streaming_response=response.content, sync_stream=True
            )

        # LOGGING
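Editor's note: on the caller side, acompletion_stream_function returns a regular litellm CustomStreamWrapper, so the new path is consumed like any other provider stream; the test added below exercises this end to end. A rough usage sketch (the model name is illustrative and an OPENAI_API_KEY must be set in the environment):

import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="aiohttp_openai/gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
        stream=True,
        max_tokens=50,
    )
    async for chunk in response:
        # CustomStreamWrapper yields OpenAI-style chunks with a delta per choice.
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="")


asyncio.run(main())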
@@ -10,6 +10,8 @@ sys.path.insert(

import litellm

from local_testing.test_streaming import streaming_format_tests


@pytest.mark.asyncio()
async def test_aiohttp_openai():
@@ -31,3 +33,39 @@ async def test_aiohttp_openai_gpt_4o():

        messages=[{"role": "user", "content": "Hello, world!"}],
    )
    print(response)


@pytest.mark.asyncio()
async def test_completion_model_stream():
    litellm.set_verbose = True
    api_key = os.getenv("OPENAI_API_KEY")
    assert api_key is not None, "API key is not set in environment variables"

    try:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "how does a court case get to the Supreme Court?",
            },
        ]
        response = await litellm.acompletion(
            api_key=api_key,
            model="aiohttp_openai/gpt-4o",
            messages=messages,
            stream=True,
            max_tokens=50,
        )

        complete_response = ""
        idx = 0  # Initialize index manually
        async for chunk in response:  # Use async for to handle async iterator
            chunk, finished = streaming_format_tests(idx, chunk)
            print(f"outside chunk: {chunk}")
            if finished:
                break
            complete_response += chunk
            idx += 1  # Increment index manually

        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"complete response: {complete_response}")

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")