fix(llm_http_handler.py): fix fake streaming (#10061)

* fix(llm_http_handler.py): fix fake streaming

allows groq to work with llm_http_handler

* fix(groq.py): migrate groq to openai like config

ensures json mode handling works correctly
Krish Dholakia 2025-04-16 10:15:11 -07:00 committed by GitHub
parent dc29fc2ea0
commit 3a3cc97fc8
5 changed files with 157 additions and 19 deletions
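
For orientation, a minimal usage sketch of the behavior this commit enables (the Groq model name and prompt are illustrative, and whether fake streaming kicks in depends on the provider config's heuristic):

# Minimal sketch, assuming GROQ_API_KEY is set; the model name is illustrative.
import litellm

response = litellm.completion(
    model="groq/llama-3.1-8b-instant",
    messages=[{"role": "user", "content": "Reply with a JSON object with a 'city' key."}],
    response_format={"type": "json_object"},
    stream=True,  # with JSON mode, Groq responses can be replayed via the fake-stream path added below
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")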

View file

@@ -1,9 +1,16 @@
 import json
 from abc import abstractmethod
-from typing import Optional, Union
+from typing import List, Optional, Union, cast

 import litellm
-from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
+from litellm.types.utils import (
+    Choices,
+    Delta,
+    GenericStreamingChunk,
+    ModelResponse,
+    ModelResponseStream,
+    StreamingChoices,
+)


 class BaseModelResponseIterator:
@@ -121,6 +128,59 @@ class BaseModelResponseIterator:
             raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


+class MockResponseIterator:  # for returning ai21 streaming responses
+    def __init__(
+        self, model_response: ModelResponse, json_mode: Optional[bool] = False
+    ):
+        self.model_response = model_response
+        self.json_mode = json_mode
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _chunk_parser(self, chunk_data: ModelResponse) -> ModelResponseStream:
+        try:
+            streaming_choices: List[StreamingChoices] = []
+            for choice in chunk_data.choices:
+                streaming_choices.append(
+                    StreamingChoices(
+                        index=choice.index,
+                        delta=Delta(
+                            **cast(Choices, choice).message.model_dump(),
+                        ),
+                        finish_reason=choice.finish_reason,
+                    )
+                )
+            processed_chunk = ModelResponseStream(
+                id=chunk_data.id,
+                object="chat.completion",
+                created=chunk_data.created,
+                model=chunk_data.model,
+                choices=streaming_choices,
+            )
+            return processed_chunk
+        except Exception as e:
+            raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}")
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+
 class FakeStreamResponseIterator:
     def __init__(self, model_response, json_mode: Optional[bool] = False):
         self.model_response = model_response
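
To make the new iterator's contract concrete, a small usage sketch (the response content and model name are illustrative; the import path matches the handler change below): MockResponseIterator wraps one fully-formed ModelResponse and replays it as a single stream chunk, synchronously or asynchronously.

# Illustrative sketch only: wrap a complete ModelResponse and iterate it once.
from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
from litellm.types.utils import ModelResponse

full_response = ModelResponse(
    model="groq/llama-3.1-8b-instant",  # illustrative model name
    choices=[
        {
            "index": 0,
            "message": {"role": "assistant", "content": '{"city": "Paris"}'},
            "finish_reason": "stop",
        }
    ],
)

chunks = list(MockResponseIterator(model_response=full_response, json_mode=True))
assert len(chunks) == 1  # the whole response comes back as one ModelResponseStream chunk
print(chunks[0].choices[0].delta.content)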

View file

@@ -11,6 +11,7 @@ from litellm._logging import verbose_logger
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )
+from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.base_llm.files.transformation import BaseFilesConfig
@@ -231,6 +232,7 @@ class BaseLLMHTTPHandler:
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
         extra_body: Optional[dict] = optional_params.pop("extra_body", None)
+        fake_stream = fake_stream or optional_params.pop("fake_stream", False)

         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
@@ -317,6 +319,7 @@
                     ),
                     litellm_params=litellm_params,
                     json_mode=json_mode,
+                    optional_params=optional_params,
                 )
             else:
@@ -378,6 +381,7 @@
                     ),
                     litellm_params=litellm_params,
                     json_mode=json_mode,
+                    optional_params=optional_params,
                 )
                 return CustomStreamWrapper(
                     completion_stream=completion_stream,
@@ -426,6 +430,7 @@
         model: str,
         messages: list,
         logging_obj,
+        optional_params: dict,
         litellm_params: dict,
         timeout: Union[float, httpx.Timeout],
         fake_stream: bool = False,
@@ -457,11 +462,22 @@
         )

         if fake_stream is True:
-            completion_stream = provider_config.get_model_response_iterator(
-                streaming_response=response.json(),
-                sync_stream=True,
+            model_response: (ModelResponse) = provider_config.transform_response(
+                model=model,
+                raw_response=response,
+                model_response=litellm.ModelResponse(),
+                logging_obj=logging_obj,
+                request_data=data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=None,
                 json_mode=json_mode,
             )
+
+            completion_stream: Any = MockResponseIterator(
+                model_response=model_response, json_mode=json_mode
+            )
         else:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.iter_lines(),
@@ -491,6 +507,7 @@
         logging_obj: LiteLLMLoggingObj,
         data: dict,
         litellm_params: dict,
+        optional_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
         json_mode: Optional[bool] = None,
@@ -509,6 +526,7 @@
         )
         completion_stream, _response_headers = await self.make_async_call_stream_helper(
+            model=model,
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
             api_base=api_base,
@@ -520,6 +538,8 @@
             fake_stream=fake_stream,
             client=client,
             litellm_params=litellm_params,
+            optional_params=optional_params,
+            json_mode=json_mode,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -531,6 +551,7 @@
     async def make_async_call_stream_helper(
         self,
+        model: str,
         custom_llm_provider: str,
         provider_config: BaseConfig,
         api_base: str,
@@ -540,8 +561,10 @@
         logging_obj: LiteLLMLoggingObj,
         timeout: Union[float, httpx.Timeout],
         litellm_params: dict,
+        optional_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> Tuple[Any, httpx.Headers]:
         """
         Helper function for making an async call with stream.
@@ -572,8 +595,21 @@
         )

         if fake_stream is True:
-            completion_stream = provider_config.get_model_response_iterator(
-                streaming_response=response.json(), sync_stream=False
+            model_response: (ModelResponse) = provider_config.transform_response(
+                model=model,
+                raw_response=response,
+                model_response=litellm.ModelResponse(),
+                logging_obj=logging_obj,
+                request_data=data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=None,
+                json_mode=json_mode,
+            )
+
+            completion_stream: Any = MockResponseIterator(
+                model_response=model_response, json_mode=json_mode
             )
         else:
             completion_stream = provider_config.get_model_response_iterator(
@@ -598,8 +634,12 @@
         """
         Some providers like Bedrock invoke do not support the stream parameter in the request body, we only pass `stream` in the request body the provider supports it.
         """
         if fake_stream is True:
-            return data
+            # remove 'stream' from data
+            new_data = data.copy()
+            new_data.pop("stream", None)
+            return new_data
+
         if provider_config.supports_stream_param_in_request_body is True:
             data["stream"] = True
         return data
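
The hunk above shows only the helper's body, not its name; as a stand-alone sketch of the resulting behavior (prepare_request_body is a stand-in name, not the real method):

# Stand-alone sketch of the body-preparation logic in the hunk above.
# prepare_request_body is a stand-in name; in litellm this logic lives on the handler.
def prepare_request_body(data: dict, fake_stream: bool, supports_stream_param: bool) -> dict:
    if fake_stream:
        # The upstream request is made without streaming, so 'stream' must not be sent.
        new_data = data.copy()
        new_data.pop("stream", None)
        return new_data
    if supports_stream_param:
        data["stream"] = True
    return data


body = prepare_request_body({"messages": [], "stream": True}, fake_stream=True, supports_stream_param=True)
assert "stream" not in body  # fake streaming sends a plain, non-streaming request upstream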

View file

@@ -14,10 +14,10 @@ from litellm.types.llms.openai import (
     ChatCompletionToolParamFunctionChunk,
 )

-from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ...openai_like.chat.transformation import OpenAILikeChatConfig


-class GroqChatConfig(OpenAIGPTConfig):
+class GroqChatConfig(OpenAILikeChatConfig):
     frequency_penalty: Optional[int] = None
     function_call: Optional[Union[str, dict]] = None
     functions: Optional[list] = None
@@ -132,8 +132,11 @@ class GroqChatConfig(OpenAIGPTConfig):
         optional_params: dict,
         model: str,
         drop_params: bool = False,
+        replace_max_completion_tokens_with_max_tokens: bool = False,  # groq supports max_completion_tokens
     ) -> dict:
         _response_format = non_default_params.get("response_format")
+        if self._should_fake_stream(non_default_params):
+            optional_params["fake_stream"] = True
         if _response_format is not None and isinstance(_response_format, dict):
             json_schema: Optional[dict] = None
             if "response_schema" in _response_format:
@@ -160,6 +163,8 @@ class GroqChatConfig(OpenAIGPTConfig):
             non_default_params.pop(
                 "response_format", None
             )  # only remove if it's a json_schema - handled via using groq's tool calling params.
-        return super().map_openai_params(
+        optional_params = super().map_openai_params(
             non_default_params, optional_params, model, drop_params
         )
+
+        return optional_params

View file

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 import httpx

 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import ChatCompletionAssistantMessage
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
 from litellm.types.utils import ModelResponse

 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
@@ -25,7 +25,6 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
         self,
         api_base: Optional[str],
         api_key: Optional[str],
-        model: Optional[str] = None,
     ) -> Tuple[Optional[str], Optional[str]]:
         api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE")  # type: ignore
         dynamic_api_key = (
@@ -74,8 +73,8 @@
         messages: List,
         print_verbose,
         encoding,
-        json_mode: bool,
-        custom_llm_provider: str,
+        json_mode: Optional[bool],
+        custom_llm_provider: Optional[str],
         base_model: Optional[str],
     ) -> ModelResponse:
         response_json = response.json()
@@ -97,6 +96,7 @@
         returned_response = ModelResponse(**response_json)

-        returned_response.model = (
-            custom_llm_provider + "/" + (returned_response.model or "")
-        )
+        if custom_llm_provider is not None:
+            returned_response.model = (
+                custom_llm_provider + "/" + (returned_response.model or "")
+            )
@@ -105,6 +105,37 @@
             returned_response._hidden_params["model"] = base_model
         return returned_response

+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return OpenAILikeChatConfig._transform_response(
+            model=model,
+            response=raw_response,
+            model_response=model_response,
+            stream=optional_params.get("stream", False),
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key=api_key,
+            data=request_data,
+            messages=messages,
+            print_verbose=None,
+            encoding=None,
+            json_mode=json_mode,
+            custom_llm_provider=None,
+            base_model=None,
+        )
+
     def map_openai_params(
         self,
         non_default_params: dict,

View file

@@ -1011,7 +1011,9 @@ class ModelResponseStream(ModelResponseBase):
     def __init__(
         self,
-        choices: Optional[List[Union[StreamingChoices, dict, BaseModel]]] = None,
+        choices: Optional[
+            Union[List[StreamingChoices], Union[StreamingChoices, dict, BaseModel]]
+        ] = None,
         id: Optional[str] = None,
         created: Optional[int] = None,
         provider_specific_fields: Optional[Dict[str, Any]] = None,
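
The widened annotation documents the input MockResponseIterator already passes: a plain list of StreamingChoices. A small construction sketch (field values are illustrative):

# Illustrative construction matching the widened choices annotation above.
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices

chunk = ModelResponseStream(
    id="chatcmpl-123",  # illustrative id
    model="groq/llama-3.1-8b-instant",  # illustrative model name
    choices=[
        StreamingChoices(
            index=0,
            delta=Delta(role="assistant", content='{"city": "Paris"}'),
            finish_reason="stop",
        )
    ],
)
print(chunk.choices[0].delta.content)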