From 47e811d6cee259a9a62b8b49214003e8ac1b20ca Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Wed, 16 Apr 2025 10:15:11 -0700
Subject: [PATCH] fix(llm_http_handler.py): fix fake streaming (#10061)

* fix(llm_http_handler.py): fix fake streaming

allows groq to work with llm_http_handler

* fix(groq.py): migrate groq to openai like config

ensures json mode handling works correctly
---
 litellm/llms/base_llm/base_model_iterator.py  | 64 ++++++++++++++++++-
 litellm/llms/custom_httpx/llm_http_handler.py | 52 +++++++++++++--
 litellm/llms/groq/chat/transformation.py      | 11 +++-
 .../llms/openai_like/chat/transformation.py   | 45 +++++++++++--
 litellm/types/utils.py                        |  4 +-
 5 files changed, 157 insertions(+), 19 deletions(-)

diff --git a/litellm/llms/base_llm/base_model_iterator.py b/litellm/llms/base_llm/base_model_iterator.py
index 90dcc52fef..4cf757d6cd 100644
--- a/litellm/llms/base_llm/base_model_iterator.py
+++ b/litellm/llms/base_llm/base_model_iterator.py
@@ -1,9 +1,16 @@
 import json
 from abc import abstractmethod
-from typing import Optional, Union
+from typing import List, Optional, Union, cast
 
 import litellm
-from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
+from litellm.types.utils import (
+    Choices,
+    Delta,
+    GenericStreamingChunk,
+    ModelResponse,
+    ModelResponseStream,
+    StreamingChoices,
+)
 
 
 class BaseModelResponseIterator:
@@ -121,6 +128,59 @@ class BaseModelResponseIterator:
             raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
 
 
+class MockResponseIterator:  # for returning ai21 streaming responses
+    def __init__(
+        self, model_response: ModelResponse, json_mode: Optional[bool] = False
+    ):
+        self.model_response = model_response
+        self.json_mode = json_mode
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _chunk_parser(self, chunk_data: ModelResponse) -> ModelResponseStream:
+        try:
+            streaming_choices: List[StreamingChoices] = []
+            for choice in chunk_data.choices:
+                streaming_choices.append(
+                    StreamingChoices(
+                        index=choice.index,
+                        delta=Delta(
+                            **cast(Choices, choice).message.model_dump(),
+                        ),
+                        finish_reason=choice.finish_reason,
+                    )
+                )
+            processed_chunk = ModelResponseStream(
+                id=chunk_data.id,
+                object="chat.completion",
+                created=chunk_data.created,
+                model=chunk_data.model,
+                choices=streaming_choices,
+            )
+            return processed_chunk
+        except Exception as e:
+            raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}")
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+
 class FakeStreamResponseIterator:
     def __init__(self, model_response, json_mode: Optional[bool] = False):
         self.model_response = model_response
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index 4a48120e1e..1ab8a94adf 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -11,6 +11,7 @@ from litellm._logging import verbose_logger
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )
+from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.base_llm.files.transformation import BaseFilesConfig
@@ -231,6 +232,7 @@ class BaseLLMHTTPHandler:
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
         extra_body: Optional[dict] = optional_params.pop("extra_body", None)
+        fake_stream = fake_stream or optional_params.pop("fake_stream", False)
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )
@@ -317,6 +319,7 @@ class BaseLLMHTTPHandler:
                     ),
                     litellm_params=litellm_params,
                     json_mode=json_mode,
+                    optional_params=optional_params,
                 )
 
             else:
@@ -378,6 +381,7 @@ class BaseLLMHTTPHandler:
                 ),
                 litellm_params=litellm_params,
                 json_mode=json_mode,
+                optional_params=optional_params,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -426,6 +430,7 @@ class BaseLLMHTTPHandler:
         model: str,
         messages: list,
         logging_obj,
+        optional_params: dict,
         litellm_params: dict,
         timeout: Union[float, httpx.Timeout],
         fake_stream: bool = False,
@@ -457,11 +462,22 @@ class BaseLLMHTTPHandler:
             )
 
         if fake_stream is True:
-            completion_stream = provider_config.get_model_response_iterator(
-                streaming_response=response.json(),
-                sync_stream=True,
+            model_response: (ModelResponse) = provider_config.transform_response(
+                model=model,
+                raw_response=response,
+                model_response=litellm.ModelResponse(),
+                logging_obj=logging_obj,
+                request_data=data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=None,
                 json_mode=json_mode,
             )
+
+            completion_stream: Any = MockResponseIterator(
+                model_response=model_response, json_mode=json_mode
+            )
         else:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.iter_lines(),
@@ -491,6 +507,7 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         data: dict,
         litellm_params: dict,
+        optional_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
         json_mode: Optional[bool] = None,
@@ -509,6 +526,7 @@ class BaseLLMHTTPHandler:
         )
 
         completion_stream, _response_headers = await self.make_async_call_stream_helper(
+            model=model,
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
             api_base=api_base,
@@ -520,6 +538,8 @@ class BaseLLMHTTPHandler:
             fake_stream=fake_stream,
             client=client,
             litellm_params=litellm_params,
+            optional_params=optional_params,
+            json_mode=json_mode,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -531,6 +551,7 @@ class BaseLLMHTTPHandler:
 
     async def make_async_call_stream_helper(
         self,
+        model: str,
         custom_llm_provider: str,
         provider_config: BaseConfig,
         api_base: str,
@@ -540,8 +561,10 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         timeout: Union[float, httpx.Timeout],
         litellm_params: dict,
+        optional_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> Tuple[Any, httpx.Headers]:
         """
         Helper function for making an async call with stream.
@@ -572,8 +595,21 @@ class BaseLLMHTTPHandler:
             )
 
         if fake_stream is True:
-            completion_stream = provider_config.get_model_response_iterator(
-                streaming_response=response.json(), sync_stream=False
+            model_response: (ModelResponse) = provider_config.transform_response(
+                model=model,
+                raw_response=response,
+                model_response=litellm.ModelResponse(),
+                logging_obj=logging_obj,
+                request_data=data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=None,
+                json_mode=json_mode,
+            )
+
+            completion_stream: Any = MockResponseIterator(
+                model_response=model_response, json_mode=json_mode
             )
         else:
             completion_stream = provider_config.get_model_response_iterator(
@@ -598,8 +634,12 @@ class BaseLLMHTTPHandler:
         """
         Some providers like Bedrock invoke do not support the stream parameter in the request body, we only pass `stream` in the request body the provider supports it.
         """
+
         if fake_stream is True:
-            return data
+            # remove 'stream' from data
+            new_data = data.copy()
+            new_data.pop("stream", None)
+            return new_data
         if provider_config.supports_stream_param_in_request_body is True:
             data["stream"] = True
         return data
diff --git a/litellm/llms/groq/chat/transformation.py b/litellm/llms/groq/chat/transformation.py
index a8972635f3..4befdc504e 100644
--- a/litellm/llms/groq/chat/transformation.py
+++ b/litellm/llms/groq/chat/transformation.py
@@ -14,10 +14,10 @@ from litellm.types.llms.openai import (
     ChatCompletionToolParamFunctionChunk,
 )
 
-from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ...openai_like.chat.transformation import OpenAILikeChatConfig
 
 
-class GroqChatConfig(OpenAIGPTConfig):
+class GroqChatConfig(OpenAILikeChatConfig):
     frequency_penalty: Optional[int] = None
     function_call: Optional[Union[str, dict]] = None
     functions: Optional[list] = None
@@ -132,8 +132,11 @@ class GroqChatConfig(OpenAIGPTConfig):
         optional_params: dict,
         model: str,
         drop_params: bool = False,
+        replace_max_completion_tokens_with_max_tokens: bool = False,  # groq supports max_completion_tokens
     ) -> dict:
         _response_format = non_default_params.get("response_format")
+        if self._should_fake_stream(non_default_params):
+            optional_params["fake_stream"] = True
         if _response_format is not None and isinstance(_response_format, dict):
             json_schema: Optional[dict] = None
             if "response_schema" in _response_format:
@@ -160,6 +163,8 @@ class GroqChatConfig(OpenAIGPTConfig):
             non_default_params.pop(
                 "response_format", None
             )  # only remove if it's a json_schema - handled via using groq's tool calling params.
-        return super().map_openai_params(
+        optional_params = super().map_openai_params(
             non_default_params, optional_params, model, drop_params
         )
+
+        return optional_params
diff --git a/litellm/llms/openai_like/chat/transformation.py b/litellm/llms/openai_like/chat/transformation.py
index ea9757a855..068d3d8dfd 100644
--- a/litellm/llms/openai_like/chat/transformation.py
+++ b/litellm/llms/openai_like/chat/transformation.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 import httpx
 
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import ChatCompletionAssistantMessage
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
 from litellm.types.utils import ModelResponse
 
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
@@ -25,7 +25,6 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
         self,
         api_base: Optional[str],
         api_key: Optional[str],
-        model: Optional[str] = None,
     ) -> Tuple[Optional[str], Optional[str]]:
         api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE")  # type: ignore
         dynamic_api_key = (
@@ -74,8 +73,8 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
         messages: List,
         print_verbose,
         encoding,
-        json_mode: bool,
-        custom_llm_provider: str,
+        json_mode: Optional[bool],
+        custom_llm_provider: Optional[str],
         base_model: Optional[str],
     ) -> ModelResponse:
         response_json = response.json()
@@ -97,14 +96,46 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
 
         returned_response = ModelResponse(**response_json)
 
-        returned_response.model = (
-            custom_llm_provider + "/" + (returned_response.model or "")
-        )
+        if custom_llm_provider is not None:
+            returned_response.model = (
+                custom_llm_provider + "/" + (returned_response.model or "")
+            )
 
         if base_model is not None:
             returned_response._hidden_params["model"] = base_model
         return returned_response
 
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return OpenAILikeChatConfig._transform_response(
+            model=model,
+            response=raw_response,
+            model_response=model_response,
+            stream=optional_params.get("stream", False),
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key=api_key,
+            data=request_data,
+            messages=messages,
+            print_verbose=None,
+            encoding=None,
+            json_mode=json_mode,
+            custom_llm_provider=None,
+            base_model=None,
+        )
+
     def map_openai_params(
         self,
         non_default_params: dict,
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index d15c66ab98..35a584b6cf 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1011,7 +1011,9 @@ class ModelResponseStream(ModelResponseBase):
 
     def __init__(
         self,
-        choices: Optional[List[Union[StreamingChoices, dict, BaseModel]]] = None,
+        choices: Optional[
+            Union[List[StreamingChoices], Union[StreamingChoices, dict, BaseModel]]
+        ] = None,
         id: Optional[str] = None,
         created: Optional[int] = None,
         provider_specific_fields: Optional[Dict[str, Any]] = None,
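
Illustrative note (not part of the patch): a minimal sketch of what the fake-stream path produces once the change above is applied. MockResponseIterator wraps one fully transformed ModelResponse and yields it as a single stream chunk, so CustomStreamWrapper consumers see an ordinary, one-chunk stream. The model name and message content below are hypothetical.

    from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
    from litellm.types.utils import Choices, Message, ModelResponse

    # A complete, non-streaming response, as provider_config.transform_response() would return it.
    response = ModelResponse(
        model="groq/llama-3.3-70b-versatile",  # hypothetical model name
        choices=[
            Choices(
                index=0,
                message=Message(content='{"answer": 42}'),
                finish_reason="stop",
            )
        ],
    )

    # The iterator emits the whole response as one ModelResponseStream chunk, then stops.
    for chunk in MockResponseIterator(model_response=response, json_mode=True):
        print(chunk.choices[0].delta.content)  # -> '{"answer": 42}'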