diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py
index 8dff433a97..041c01abcf 100644
--- a/litellm/litellm_core_utils/streaming_handler.py
+++ b/litellm/litellm_core_utils/streaming_handler.py
@@ -803,49 +803,6 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e
 
-    def handle_clarifai_completion_chunk(self, chunk):
-        try:
-            if isinstance(chunk, dict):
-                parsed_response = chunk
-            elif isinstance(chunk, (str, bytes)):
-                if isinstance(chunk, bytes):
-                    parsed_response = chunk.decode("utf-8")
-                else:
-                    parsed_response = chunk
-            else:
-                raise ValueError("Unable to parse streaming chunk")
-            if isinstance(parsed_response, dict):
-                data_json = parsed_response
-            else:
-                data_json = json.loads(parsed_response)
-            text = (
-                data_json.get("outputs", "")[0]
-                .get("data", "")
-                .get("text", "")
-                .get("raw", "")
-            )
-            len(
-                encoding.encode(
-                    data_json.get("outputs", "")[0]
-                    .get("input", "")
-                    .get("data", "")
-                    .get("text", "")
-                    .get("raw", "")
-                )
-            )
-            len(encoding.encode(text))
-            return {
-                "text": text,
-                "is_finished": True,
-            }
-        except Exception as e:
-            verbose_logger.exception(
-                "litellm.CustomStreamWrapper.handle_clarifai_chunk(): Exception occured - {}".format(
-                    str(e)
-                )
-            )
-            return ""
-
     def model_response_creator(
         self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
     ):
@@ -1112,11 +1069,6 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
-                response_obj = self.handle_clarifai_completion_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                 response_obj = self.handle_replicate_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
diff --git a/litellm/llms/base_llm/base_model_iterator.py b/litellm/llms/base_llm/base_model_iterator.py
new file mode 100644
index 0000000000..530a6a79be
--- /dev/null
+++ b/litellm/llms/base_llm/base_model_iterator.py
@@ -0,0 +1,43 @@
+import json
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+    ModelResponse,
+)
+
+
+class FakeStreamResponseIterator:
+    def __init__(self, model_response, json_mode: Optional[bool] = False):
+        self.model_response = model_response
+        self.json_mode = json_mode
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    @abstractmethod
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        pass
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
diff --git a/litellm/llms/base_llm/transformation.py b/litellm/llms/base_llm/transformation.py
index 1d7fe31983..7b7cb82587 100644
--- a/litellm/llms/base_llm/transformation.py
+++ b/litellm/llms/base_llm/transformation.py
@@ -136,7 +136,7 @@ class BaseConfig(ABC):
     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ) -> Any:
diff --git a/litellm/llms/clarifai/chat/handler.py b/litellm/llms/clarifai/chat/handler.py
deleted file mode 100644
index cf6da51cfd..0000000000
--- a/litellm/llms/clarifai/chat/handler.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import json
-import os
-import time
-import traceback
-import types
-from typing import Callable, List, Optional
-
-import httpx
-import requests
-
-import litellm
-from litellm.llms.custom_httpx.http_handler import (
-    AsyncHTTPHandler,
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.types.llms.openai import AllMessageValues
-from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
-
-from ...prompt_templates.factory import custom_prompt, prompt_factory
-from ..common_utils import ClarifaiError
-
-
-async def async_completion(
-    model: str,
-    messages: List[AllMessageValues],
-    model_response: ModelResponse,
-    encoding,
-    api_key,
-    api_base: str,
-    logging_obj,
-    data: dict,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-    headers={},
-):
-
-    async_handler = get_async_httpx_client(
-        llm_provider=litellm.LlmProviders.CLARIFAI,
-        params={"timeout": 600.0},
-    )
-    response = await async_handler.post(
-        url=api_base, headers=headers, data=json.dumps(data)
-    )
-
-    return litellm.ClarifaiConfig().transform_response(
-        model=model,
-        raw_response=response,
-        model_response=model_response,
-        logging_obj=logging_obj,
-        api_key=api_key,
-        request_data=data,
-        messages=messages,
-        optional_params=optional_params,
-        encoding=encoding,
-    )
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params: dict,
-    litellm_params: dict,
-    custom_prompt_dict={},
-    acompletion=False,
-    logger_fn=None,
-    headers={},
-):
-    headers = litellm.ClarifaiConfig().validate_environment(
-        api_key=api_key,
-        headers=headers,
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-    )
-    data = litellm.ClarifaiConfig().transform_request(
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-        litellm_params=litellm_params,
-        headers=headers,
-    )
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=data,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": model,
-        },
-    )
-    if acompletion is True:
-        return async_completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            model_response=model_response,
-            encoding=encoding,
-            api_key=api_key,
-            logging_obj=logging_obj,
-            data=data,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=headers,
-        )
-    else:
-        ## COMPLETION CALL
-        httpx_client = _get_httpx_client(
-            params={"timeout": 600.0},
-        )
-        response = httpx_client.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-        )
-
-        if response.status_code != 200:
-            raise ClarifaiError(status_code=response.status_code, message=response.text)
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            completion_stream = response.iter_lines()
-            stream_response = CustomStreamWrapper(
-                completion_stream=completion_stream,
-                model=model,
-                custom_llm_provider="clarifai",
-                logging_obj=logging_obj,
-            )
-            return stream_response
-
-        else:
-            return litellm.ClarifaiConfig().transform_response(
-                model=model,
-                raw_response=response,
-                model_response=model_response,
-                logging_obj=logging_obj,
-                api_key=api_key,
-                request_data=data,
-                messages=messages,
-                optional_params=optional_params,
-                encoding=encoding,
-            )
-
-
-class ModelResponseIterator:
-    def __init__(self, model_response):
-        self.model_response = model_response
-        self.is_done = False
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.is_done:
-            raise StopIteration
-        self.is_done = True
-        return self.model_response
-
-    # Async iterator
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        if self.is_done:
-            raise StopAsyncIteration
-        self.is_done = True
-        return self.model_response
diff --git a/litellm/llms/clarifai/chat/transformation.py b/litellm/llms/clarifai/chat/transformation.py
index dde3b42a3c..ae2705d025 100644
--- a/litellm/llms/clarifai/chat/transformation.py
+++ b/litellm/llms/clarifai/chat/transformation.py
@@ -1,13 +1,23 @@
+import json
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 
 import httpx
 
 import litellm
+from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+    Usage,
+)
 from litellm.utils import token_counter
 
 from ..common_utils import ClarifaiError
@@ -199,3 +209,56 @@ class ClarifaiConfig(BaseConfig):
             ),
         )
         return model_response
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return ClarifaiModelResponseIterator(
+            model_response=streaming_response,
+            json_mode=json_mode,
+        )
+
+
+class ClarifaiModelResponseIterator(FakeStreamResponseIterator):
+    def __init__(
+        self,
+        model_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        json_mode: Optional[bool] = False,
+    ):
+        super().__init__(
+            model_response=model_response,
+            json_mode=json_mode,
+        )
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            text = (
+                chunk.get("outputs", "")[0]
+                .get("data", "")
+                .get("text", "")
+                .get("raw", "")
+            )
+
+            index: int = 0
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py
index 7723907569..e5af5a0235 100644
--- a/litellm/llms/cohere/chat/transformation.py
+++ b/litellm/llms/cohere/chat/transformation.py
@@ -368,7 +368,7 @@ class CohereChatConfig(BaseConfig):
     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ):
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index f960cbbb55..51df6dfbbe 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -93,6 +93,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         acompletion: bool,
         stream: Optional[bool] = False,
+        fake_stream: bool = False,
        api_key: Optional[str] = None,
         headers={},
     ):
@@ -129,7 +130,8 @@ class BaseLLMHTTPHandler:
 
         if acompletion is True:
             if stream is True:
-                data["stream"] = stream
+                if fake_stream is not True:
+                    data["stream"] = stream
                 return self.acompletion_stream_function(
                     model=model,
                     messages=messages,
@@ -140,6 +142,7 @@ class BaseLLMHTTPHandler:
                     timeout=timeout,
                     logging_obj=logging_obj,
                     data=data,
+                    fake_stream=fake_stream,
                 )
 
             else:
@@ -160,7 +163,8 @@ class BaseLLMHTTPHandler:
             )
 
         if stream is True:
-            data["stream"] = stream
+            if fake_stream is not True:
+                data["stream"] = stream
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
                 api_base=api_base,
@@ -170,6 +174,7 @@ class BaseLLMHTTPHandler:
                 messages=messages,
                 logging_obj=logging_obj,
                 timeout=timeout,
+                fake_stream=fake_stream,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -215,11 +220,15 @@ class BaseLLMHTTPHandler:
         messages: list,
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         sync_httpx_client = _get_httpx_client()
         try:
+            stream = True
+            if fake_stream is True:
+                stream = False
             response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, timeout=timeout, stream=stream
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -240,9 +249,15 @@ class BaseLLMHTTPHandler:
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.iter_lines(), sync_stream=True
-        )
+
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )
 
         # LOGGING
         logging_obj.post_call(
@@ -265,8 +280,8 @@ class BaseLLMHTTPHandler:
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
         data: dict,
+        fake_stream: bool = False,
     ):
-        data["stream"] = True
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
@@ -276,6 +291,7 @@ class BaseLLMHTTPHandler:
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            fake_stream=fake_stream,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -295,13 +311,17 @@ class BaseLLMHTTPHandler:
         messages: list,
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         async_httpx_client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders(custom_llm_provider)
         )
+        stream = True
+        if fake_stream is True:
+            stream = False
         try:
             response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, stream=stream, timeout=timeout
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -322,10 +342,14 @@ class BaseLLMHTTPHandler:
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.aiter_lines(), sync_stream=False
-        )
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=False
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.aiter_lines(), sync_stream=False
+            )
         # LOGGING
         logging_obj.post_call(
             input=messages,
diff --git a/litellm/main.py b/litellm/main.py
index 71621f14b8..8713bb932e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -110,7 +110,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.clarifai.chat import handler
 from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
@@ -1689,41 +1688,23 @@ def completion(  # type: ignore # noqa: PLR0915
             or "https://api.clarifai.com/v2"
         )
         api_base = litellm.ClarifaiConfig()._convert_model_to_url(model, api_base)
-        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        model_response = handler.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
+            fake_stream=True,  # clarifai does not support streaming, we fake it
             messages=messages,
+            acompletion=acompletion,
             api_base=api_base,
             model_response=model_response,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            acompletion=acompletion,
-            logger_fn=logger_fn,
-            encoding=encoding,  # for calculating input/output tokens
+            custom_llm_provider="clarifai",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
             api_key=clarifai_key,
-            logging_obj=logging,
-            custom_prompt_dict=custom_prompt_dict,
+            logging_obj=logging,  # model call logging is done inside the class, as we may need to modify I/O to fit Clarifai's requirements
         )
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            # don't try to access stream object,
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=model_response,
-            )
-
-        if optional_params.get("stream", False) or acompletion is True:
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=clarifai_key,
-                original_response=model_response,
-            )
-        response = model_response
-
    elif custom_llm_provider == "anthropic":
         api_key = (
             api_key
diff --git a/tests/local_testing/test_clarifai_completion.py b/tests/llm_translation/test_clarifai_completion.py
similarity index 100%
rename from tests/local_testing/test_clarifai_completion.py
rename to tests/llm_translation/test_clarifai_completion.py
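
Usage note (not part of the diff): with fake_stream=True, a stream=True request for a Clarifai model is sent to the provider as a single non-streaming HTTP call, and the parsed JSON response is replayed through ClarifaiModelResponseIterator (a FakeStreamResponseIterator) as one streaming chunk. A minimal caller-side sketch is below; the model id is illustrative and a valid CLARIFAI_API_KEY is assumed to be set in the environment.

    import litellm

    # Requires CLARIFAI_API_KEY in the environment; the model id is an example, not taken from this PR.
    response = litellm.completion(
        model="clarifai/meta.Llama-2.llama2-7b-chat",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        stream=True,  # Clarifai has no native streaming; the handler now fakes it with a single chunk
    )

    # The CustomStreamWrapper yields one chunk carrying the full completion, then the stream ends.
    for chunk in response:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="")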