Jaswanth Karani 2025-04-24 00:59:23 -07:00 committed by GitHub
commit fe407c75b6
4 changed files with 339 additions and 42 deletions


@ -7,13 +7,14 @@ https://github.com/BerriAI/litellm/issues/6592
New config to ensure we introduce this without causing breaking changes for users
"""
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, AsyncIterator, Union, Iterator, List, Optional
from aiohttp import ClientResponse
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, ModelResponse
from ..common_utils import ModelResponseIterator as AiohttpOpenAIResponseIterator
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@ -80,3 +81,15 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
model_response.object = _json_response.get("object")
model_response.system_fingerprint = _json_response.get("system_fingerprint")
return model_response
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return AiohttpOpenAIResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
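
For orientation: this override is what the aiohttp handler (the third file in this diff) calls to turn a raw response into a chunk iterator. Below is a minimal sketch of that call site, using only names that appear elsewhere in this diff; the surrounding variables (provider_config, response, model, custom_llm_provider, logging_obj) are assumed to already be in scope.

# Illustrative sketch, not part of this commit. Mirrors how make_async_call
# (further down in this diff) consumes the new hook for a real stream.
completion_stream = provider_config.get_model_response_iterator(
    streaming_response=response.content,  # aiohttp StreamReader of raw SSE bytes
    sync_stream=False,
)
stream = CustomStreamWrapper(
    completion_stream=completion_stream,
    model=model,
    custom_llm_provider=custom_llm_provider,  # "aiohttp_openai"
    logging_obj=logging_obj,
)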


@ -0,0 +1,169 @@
import json
from typing import List, Optional, Union
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
ModelResponseStream
)
class AioHttpOpenAIError(BaseLLMException):
def __init__(self, status_code, message):
super().__init__(status_code=status_code, message=message)
def validate_environment(
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
"""
Return headers to use for the aiohttp_openai chat completion request
"""
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"bearer {api_key}"
return headers
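
A quick illustration of what the helper above produces; the API key is a placeholder.

# Illustrative only; not part of this commit.
headers = validate_environment(
    headers={}, model="gpt-4o", messages=[], optional_params={}, api_key="sk-placeholder"
)
# headers == {
#     "Request-Source": "unspecified:litellm",
#     "accept": "application/json",
#     "content-type": "application/json",
#     "Authorization": "bearer sk-placeholder",
# }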
class ModelResponseIterator:
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.json_mode = json_mode
def chunk_parser(self, chunk: dict) -> Union[GenericStreamingChunk, ModelResponseStream]:
try:
# Initialize default values
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
# Extract the index from the chunk
index = int(chunk.get("choices", [{}])[0].get("index", 0))
# Extract the text or delta content from the first choice
delta = chunk.get("choices", [{}])[0].get("delta", {})
if "content" in delta:
text = delta["content"]
# Check for finish_reason
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason", "")
# Determine if the stream has finished
is_finished = finish_reason in ("length", "stop")
# Create and return the parsed chunk
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def _handle_string_chunk(
self, str_line: str
) -> Union[GenericStreamingChunk, ModelResponseStream]:
# chunk is a str at this point
if "[DONE]" in str_line:
return GenericStreamingChunk(
text="",
is_finished=True,
finish_reason="stop",
usage=None,
index=0,
tool_use=None,
)
elif str_line.startswith("data:"):
data_json = json.loads(str_line[5:])
return self.chunk_parser(chunk=data_json)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
# chunk is a str at this point
return self._handle_string_chunk(str_line=str_line)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
# chunk is a str at this point
return self._handle_string_chunk(str_line=str_line)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
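
To make the parsing behaviour above concrete, here is a small sync-side sketch. The import path is an assumption inferred from the relative import in the chat transformation (from ..common_utils import ModelResponseIterator) and may differ; the SSE lines are fabricated examples of what an OpenAI-compatible endpoint emits.

# Illustrative sketch, not part of this commit.
from litellm.llms.aiohttp_openai.common_utils import ModelResponseIterator  # assumed path

sse_lines = [
    'data: {"choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]}',
    'data: {"choices": [{"index": 0, "delta": {"content": "lo"}, "finish_reason": "stop"}]}',
    "data: [DONE]",
]

iterator = ModelResponseIterator(streaming_response=iter(sse_lines), sync_stream=True)
for parsed in iterator:
    # One GenericStreamingChunk per line: text "Hel"/"lo"/"" and is_finished False/True/True.
    print(parsed["text"], parsed["is_finished"])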


@ -40,12 +40,10 @@ class BaseLLMAIOHTTPHandler:
) -> ClientSession:
if dynamic_client_session:
return dynamic_client_session
elif self.client_session:
return self.client_session
else:
# init client session, and then return new session
self.client_session = aiohttp.ClientSession()
if self.client_session:
return self.client_session
self.client_session = aiohttp.ClientSession()
return self.client_session
async def _make_common_async_call(
self,
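
The first hunk above reads more clearly in isolation: the nested if/elif/else is replaced by early returns, and the aiohttp.ClientSession is still created lazily and cached so later calls reuse its connection pool. A standalone sketch of the resulting logic follows (class and attribute names mirror the diff; everything else is illustrative).

# Illustrative sketch, not part of this commit.
from typing import Optional
import aiohttp

class _SessionCache:
    def __init__(self) -> None:
        self.client_session: Optional[aiohttp.ClientSession] = None

    def get_async_client_session(
        self, dynamic_client_session: Optional[aiohttp.ClientSession] = None
    ) -> aiohttp.ClientSession:
        if dynamic_client_session:
            return dynamic_client_session  # caller-supplied session wins
        if self.client_session:
            return self.client_session  # reuse the cached session
        self.client_session = aiohttp.ClientSession()  # create once, lazily
        return self.client_session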
@ -69,7 +67,7 @@ class BaseLLMAIOHTTPHandler:
dynamic_client_session=async_client_session
)
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
for _ in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = await async_client_session.post(
url=api_base,
@ -139,8 +137,7 @@ class BaseLLMAIOHTTPHandler:
)
)
continue
else:
raise self._handle_error(e=e, provider_config=provider_config)
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
@ -172,6 +169,7 @@ class BaseLLMAIOHTTPHandler:
api_key: Optional[str] = None,
client: Optional[ClientSession] = None,
):
data.pop("max_retries", None)  # drop litellm-internal "max_retries"; it is not a valid OpenAI request parameter
_response = await self._make_common_async_call(
async_client_session=client,
provider_config=provider_config,
@ -261,7 +259,24 @@ class BaseLLMAIOHTTPHandler:
},
)
if acompletion is True:
if acompletion:
if stream:
if not fake_stream:
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
api_base=api_base,
headers=headers,
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
timeout=timeout,
logging_obj=logging_obj,
data=data,
fake_stream=fake_stream,
client=client if isinstance(client, ClientSession) else None,
litellm_params=litellm_params,
)
return self.async_completion(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
@ -277,15 +292,11 @@ class BaseLLMAIOHTTPHandler:
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
client=(
client
if client is not None and isinstance(client, ClientSession)
else None
),
client=client if isinstance(client, ClientSession) else None,
)
if stream is True:
if fake_stream is not True:
if stream:
if not fake_stream:
data["stream"] = stream
completion_stream, headers = self.make_sync_call(
provider_config=provider_config,
@ -297,11 +308,7 @@ class BaseLLMAIOHTTPHandler:
logging_obj=logging_obj,
timeout=timeout,
fake_stream=fake_stream,
client=(
client
if client is not None and isinstance(client, HTTPHandler)
else None
),
client=client if isinstance(client, HTTPHandler) else None,
litellm_params=litellm_params,
)
return CustomStreamWrapper(
@ -311,11 +318,7 @@ class BaseLLMAIOHTTPHandler:
logging_obj=logging_obj,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
@ -338,6 +341,89 @@ class BaseLLMAIOHTTPHandler:
encoding=encoding,
)
async def acompletion_stream_function(
self,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
headers: dict,
provider_config: BaseConfig,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
data: dict,
litellm_params: dict,
fake_stream: bool = False,
client: Optional[ClientSession] = None,
):
completion_stream, _response_headers = await self.make_async_call(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
fake_stream=fake_stream,
client=client,
litellm_params=litellm_params,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
async def make_async_call(
self,
custom_llm_provider: str,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
messages: list,
logging_obj: LiteLLMLoggingObj,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
fake_stream: bool = False,
client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None,
) -> Tuple[Any, httpx.Headers]:
async_client_session = client if isinstance(client, ClientSession) else self._get_async_client_session()
stream = not fake_stream
data.pop("max_retries", None)
response = await self._make_common_async_call(
async_client_session=async_client_session,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=stream,
)
if fake_stream:
json_response = await response.json()
completion_stream = provider_config.get_model_response_iterator(
streaming_response=json_response, sync_stream=False
)
else:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.content, sync_stream=False
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream, response.headers
def make_sync_call(
self,
provider_config: BaseConfig,
@ -352,13 +438,8 @@ class BaseLLMAIOHTTPHandler:
fake_stream: bool = False,
client: Optional[HTTPHandler] = None,
) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
stream = True
if fake_stream is True:
stream = False
sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
stream = not fake_stream
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
@ -371,13 +452,13 @@ class BaseLLMAIOHTTPHandler:
stream=stream,
)
if fake_stream is True:
if fake_stream:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.json(), sync_stream=True
)
else:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.iter_lines(), sync_stream=True
streaming_response=response.content, sync_stream=True
)
# LOGGING
@ -535,13 +616,9 @@ class BaseLLMAIOHTTPHandler:
litellm_params=litellm_params,
image=image,
provider_config=provider_config,
) # type: ignore
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
)
sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,


@ -10,6 +10,8 @@ sys.path.insert(
import litellm
from local_testing.test_streaming import streaming_format_tests
@pytest.mark.asyncio()
async def test_aiohttp_openai():
@ -31,3 +33,39 @@ async def test_aiohttp_openai_gpt_4o():
messages=[{"role": "user", "content": "Hello, world!"}],
)
print(response)
@pytest.mark.asyncio()
async def test_completion_model_stream():
litellm.set_verbose = True
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None, "API key is not set in environment variables"
try:
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "how does a court case get to the Supreme Court?",
},
]
response = await litellm.acompletion(
api_key=api_key, model="aiohttp_openai/gpt-4o", messages=messages, stream=True, max_tokens=50
)
complete_response = ""
idx = 0 # Initialize index manually
async for chunk in response: # Use async for to handle async iterator
chunk, finished = streaming_format_tests(idx, chunk)  # validate the chunk format and extract its text
print(f"outside chunk: {chunk}")
if finished:
break
complete_response += chunk
idx += 1 # Increment index manually
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
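
Outside the test harness, the same streaming path can be exercised directly. A minimal sketch, assuming OPENAI_API_KEY is set and that chunks follow the usual OpenAI-style streaming shape litellm returns:

# Illustrative sketch, not part of this commit.
import asyncio
import litellm

async def stream_demo() -> None:
    response = await litellm.acompletion(
        model="aiohttp_openai/gpt-4o",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        stream=True,
        max_tokens=20,
    )
    async for chunk in response:
        # delta.content may be None on the final chunk, hence the fallback.
        print(chunk.choices[0].delta.content or "", end="", flush=True)
    print()

if __name__ == "__main__":
    asyncio.run(stream_demo())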