(Refactor) Code Quality improvement - Use Common base handler for clarifai/ (#7125)

* use base_llm_http_handler for clarifai

* fix clarifai completion

* handle faking streaming base llm http handler

* add fake streaming for clarifai

* add FakeStreamResponseIterator for base model iterator

* fix get_model_response_iterator

* fix base model iterator

* fix base model iterator

* add support for faking sync streams for clarifai

* add fake streaming for clarifai

* remove unused code

* fix import

* fix llm http handler

* test_async_completion_clarifai

* fix clarifai tests

* fix linting
Ishaan Jaff 2024-12-09 21:04:48 -08:00 committed by GitHub
parent c5e0407703
commit 28ff38e35d
9 changed files with 155 additions and 269 deletions


@@ -803,49 +803,6 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e

-    def handle_clarifai_completion_chunk(self, chunk):
-        try:
-            if isinstance(chunk, dict):
-                parsed_response = chunk
-            elif isinstance(chunk, (str, bytes)):
-                if isinstance(chunk, bytes):
-                    parsed_response = chunk.decode("utf-8")
-                else:
-                    parsed_response = chunk
-            else:
-                raise ValueError("Unable to parse streaming chunk")
-            if isinstance(parsed_response, dict):
-                data_json = parsed_response
-            else:
-                data_json = json.loads(parsed_response)
-            text = (
-                data_json.get("outputs", "")[0]
-                .get("data", "")
-                .get("text", "")
-                .get("raw", "")
-            )
-            len(
-                encoding.encode(
-                    data_json.get("outputs", "")[0]
-                    .get("input", "")
-                    .get("data", "")
-                    .get("text", "")
-                    .get("raw", "")
-                )
-            )
-            len(encoding.encode(text))
-            return {
-                "text": text,
-                "is_finished": True,
-            }
-        except Exception as e:
-            verbose_logger.exception(
-                "litellm.CustomStreamWrapper.handle_clarifai_chunk(): Exception occured - {}".format(
-                    str(e)
-                )
-            )
-            return ""
-
     def model_response_creator(
         self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
     ):
@@ -1112,11 +1069,6 @@
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
-                response_obj = self.handle_clarifai_completion_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                 response_obj = self.handle_replicate_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]


@@ -0,0 +1,43 @@
+import json
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+    ModelResponse,
+)
+
+
+class FakeStreamResponseIterator:
+    def __init__(self, model_response, json_mode: Optional[bool] = False):
+        self.model_response = model_response
+        self.json_mode = json_mode
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    @abstractmethod
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        pass
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
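A brief usage sketch of the iterator pattern added above, assuming a provider subclasses it and supplies chunk_parser: the complete, non-streamed response is handed over once and surfaced as a single chunk through both the sync and async iterator protocols. The classes and response shape below are illustrative stand-ins, not litellm code.

import asyncio
from typing import Any, Dict


class OneChunkFakeStream:
    # Illustrative stand-in for FakeStreamResponseIterator: wraps one complete
    # response and yields it exactly once, sync or async.
    def __init__(self, model_response: Dict[str, Any]):
        self.model_response = model_response
        self.is_done = False

    def chunk_parser(self, chunk: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError  # providers override this

    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.chunk_parser(self.model_response)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.chunk_parser(self.model_response)


class ExampleProviderIterator(OneChunkFakeStream):
    # A provider only defines how one full response maps to one streaming chunk.
    def chunk_parser(self, chunk: Dict[str, Any]) -> Dict[str, Any]:
        return {"text": chunk.get("text", ""), "is_finished": True, "finish_reason": "stop"}


async def main() -> None:
    response = {"text": "complete, non-streamed provider output"}
    for chunk in ExampleProviderIterator(response):        # sync consumption
        print("sync:", chunk)
    async for chunk in ExampleProviderIterator(response):  # async consumption
        print("async:", chunk)


asyncio.run(main())

Keeping both iterator protocols on one base class lets the same provider iterator be returned from the sync and async code paths of the HTTP handler.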


@@ -136,7 +136,7 @@ class BaseConfig(ABC):
     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ) -> Any:
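The widened type hint above reflects the fake-streaming path: for faked streams the HTTP handler passes the fully parsed response body to get_model_response_iterator instead of an iterator of response lines. A rough, self-contained sketch of that dispatch, with a plain dict standing in for ModelResponse (names here are illustrative, not the litellm API):

from typing import Any, AsyncIterator, Dict, Iterator, Union


def get_model_response_iterator_sketch(
    streaming_response: Union[Iterator[str], AsyncIterator[str], Dict[str, Any]],
    sync_stream: bool,
) -> Any:
    # Fake stream: the complete response is already parsed; wrap it so the
    # caller can still iterate over "chunks" (exactly one).
    if isinstance(streaming_response, dict):
        return iter([streaming_response])
    # Real stream: hand back the line iterator unchanged.
    return streaming_response


# Usage with a faked stream vs. a real line stream.
print(list(get_model_response_iterator_sketch({"text": "full response"}, sync_stream=True)))
print(list(get_model_response_iterator_sketch(iter(['data: {"text": "partial"}']), sync_stream=True)))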


@@ -1,177 +0,0 @@
-import json
-import os
-import time
-import traceback
-import types
-from typing import Callable, List, Optional
-
-import httpx
-import requests
-
-import litellm
-from litellm.llms.custom_httpx.http_handler import (
-    AsyncHTTPHandler,
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.types.llms.openai import AllMessageValues
-from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
-
-from ...prompt_templates.factory import custom_prompt, prompt_factory
-from ..common_utils import ClarifaiError
-
-
-async def async_completion(
-    model: str,
-    messages: List[AllMessageValues],
-    model_response: ModelResponse,
-    encoding,
-    api_key,
-    api_base: str,
-    logging_obj,
-    data: dict,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-    headers={},
-):
-    async_handler = get_async_httpx_client(
-        llm_provider=litellm.LlmProviders.CLARIFAI,
-        params={"timeout": 600.0},
-    )
-    response = await async_handler.post(
-        url=api_base, headers=headers, data=json.dumps(data)
-    )
-
-    return litellm.ClarifaiConfig().transform_response(
-        model=model,
-        raw_response=response,
-        model_response=model_response,
-        logging_obj=logging_obj,
-        api_key=api_key,
-        request_data=data,
-        messages=messages,
-        optional_params=optional_params,
-        encoding=encoding,
-    )
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params: dict,
-    litellm_params: dict,
-    custom_prompt_dict={},
-    acompletion=False,
-    logger_fn=None,
-    headers={},
-):
-    headers = litellm.ClarifaiConfig().validate_environment(
-        api_key=api_key,
-        headers=headers,
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-    )
-    data = litellm.ClarifaiConfig().transform_request(
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-        litellm_params=litellm_params,
-        headers=headers,
-    )
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=data,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": model,
-        },
-    )
-    if acompletion is True:
-        return async_completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            model_response=model_response,
-            encoding=encoding,
-            api_key=api_key,
-            logging_obj=logging_obj,
-            data=data,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=headers,
-        )
-    else:
-        ## COMPLETION CALL
-        httpx_client = _get_httpx_client(
-            params={"timeout": 600.0},
-        )
-        response = httpx_client.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-        )
-
-        if response.status_code != 200:
-            raise ClarifaiError(status_code=response.status_code, message=response.text)
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            completion_stream = response.iter_lines()
-            stream_response = CustomStreamWrapper(
-                completion_stream=completion_stream,
-                model=model,
-                custom_llm_provider="clarifai",
-                logging_obj=logging_obj,
-            )
-            return stream_response
-        else:
-            return litellm.ClarifaiConfig().transform_response(
-                model=model,
-                raw_response=response,
-                model_response=model_response,
-                logging_obj=logging_obj,
-                api_key=api_key,
-                request_data=data,
-                messages=messages,
-                optional_params=optional_params,
-                encoding=encoding,
-            )
-
-
-class ModelResponseIterator:
-    def __init__(self, model_response):
-        self.model_response = model_response
-        self.is_done = False
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.is_done:
-            raise StopIteration
-        self.is_done = True
-        return self.model_response
-
-    # Async iterator
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        if self.is_done:
-            raise StopAsyncIteration
-        self.is_done = True
-        return self.model_response


@@ -1,13 +1,23 @@
+import json
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

 import httpx

 import litellm
+from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+    Usage,
+)
 from litellm.utils import token_counter

 from ..common_utils import ClarifaiError
@@ -199,3 +209,56 @@ class ClarifaiConfig(BaseConfig):
             ),
         )
         return model_response
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return ClarifaiModelResponseIterator(
+            model_response=streaming_response,
+            json_mode=json_mode,
+        )
+
+
+class ClarifaiModelResponseIterator(FakeStreamResponseIterator):
+    def __init__(
+        self,
+        model_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        json_mode: Optional[bool] = False,
+    ):
+        super().__init__(
+            model_response=model_response,
+            json_mode=json_mode,
+        )
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            text = (
+                chunk.get("outputs", "")[0]
+                .get("data", "")
+                .get("text", "")
+                .get("raw", "")
+            )
+
+            index: int = 0
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
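For reference, the chunk_parser above walks a Clarifai-style response body shaped as outputs[0].data.text.raw. A small, self-contained sketch with a made-up payload; the field names follow the accessors in the diff, but the payload and helper are illustrative only:

from typing import Any, Dict

# Illustrative payload shaped like the accessors used in chunk_parser above.
clarifai_style_response: Dict[str, Any] = {
    "outputs": [
        {"data": {"text": {"raw": "Hello from the model"}}}
    ]
}


def extract_text(chunk: Dict[str, Any]) -> str:
    # Mirrors chunk.get("outputs")[0]["data"]["text"]["raw"] with safe defaults.
    outputs = chunk.get("outputs") or [{}]
    return outputs[0].get("data", {}).get("text", {}).get("raw", "")


print(extract_text(clarifai_style_response))  # -> "Hello from the model"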


@@ -368,7 +368,7 @@ class CohereChatConfig(BaseConfig):
     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ):


@@ -93,6 +93,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         acompletion: bool,
         stream: Optional[bool] = False,
+        fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
     ):
@@ -129,7 +130,8 @@
         if acompletion is True:
             if stream is True:
-                data["stream"] = stream
+                if fake_stream is not True:
+                    data["stream"] = stream
                 return self.acompletion_stream_function(
                     model=model,
                     messages=messages,
@@ -140,6 +142,7 @@
                     timeout=timeout,
                     logging_obj=logging_obj,
                     data=data,
+                    fake_stream=fake_stream,
                 )
             else:
@@ -160,7 +163,8 @@
         )

         if stream is True:
-            data["stream"] = stream
+            if fake_stream is not True:
+                data["stream"] = stream
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
                 api_base=api_base,
@@ -170,6 +174,7 @@
                 messages=messages,
                 logging_obj=logging_obj,
                 timeout=timeout,
+                fake_stream=fake_stream,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -215,11 +220,15 @@
         messages: list,
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         sync_httpx_client = _get_httpx_client()
         try:
+            stream = True
+            if fake_stream is True:
+                stream = False
             response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, timeout=timeout, stream=stream
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -240,9 +249,15 @@
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.iter_lines(), sync_stream=True
-        )
+
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )

         # LOGGING
         logging_obj.post_call(
@@ -265,8 +280,8 @@
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
         data: dict,
+        fake_stream: bool = False,
     ):
-        data["stream"] = True
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
@@ -276,6 +291,7 @@
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            fake_stream=fake_stream,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -295,13 +311,17 @@
         messages: list,
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         async_httpx_client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders(custom_llm_provider)
         )
+        stream = True
+        if fake_stream is True:
+            stream = False
         try:
             response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, stream=stream, timeout=timeout
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -322,10 +342,14 @@
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.aiter_lines(), sync_stream=False
-        )
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=False
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.aiter_lines(), sync_stream=False
+            )
         # LOGGING
         logging_obj.post_call(
             input=messages,
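Putting the handler changes above together: when fake_stream is set, the request is sent without stream set in the payload, the response body is read as JSON, and that parsed body is wrapped in the one-chunk iterator; otherwise the raw line iterator is used. A condensed, self-contained sketch of that control flow with the HTTP call stubbed out (all names below are illustrative):

from typing import Any, Dict, Iterator, Union


def post_stubbed(url: str, payload: Dict[str, Any], stream: bool) -> Dict[str, Any]:
    # Stand-in for the HTTP client call; a real handler would POST and return a response object.
    return {
        "json": {"text": "full response"},
        "lines": iter(['data: {"text": "partial"}']),
    }


def make_sync_call_sketch(
    url: str, data: Dict[str, Any], fake_stream: bool = False
) -> Union[Iterator[Dict[str, Any]], Iterator[str]]:
    # Only ask the server to stream when streaming is not being faked.
    if not fake_stream:
        data["stream"] = True
    response = post_stubbed(url, data, stream=not fake_stream)

    if fake_stream:
        # Complete body -> a one-chunk "stream".
        return iter([response["json"]])
    # Real stream -> iterate the server-sent lines.
    return response["lines"]


for chunk in make_sync_call_sketch("https://example.invalid/v2", {}, fake_stream=True):
    print(chunk)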


@@ -110,7 +110,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.clarifai.chat import handler
 from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
@@ -1689,41 +1688,23 @@ def completion( # type: ignore # noqa: PLR0915
                 or "https://api.clarifai.com/v2"
             )
             api_base = litellm.ClarifaiConfig()._convert_model_to_url(model, api_base)
-            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = handler.completion(
+            response = base_llm_http_handler.completion(
                 model=model,
+                stream=stream,
+                fake_stream=True,  # clarifai does not support streaming, we fake it
                 messages=messages,
+                acompletion=acompletion,
                 api_base=api_base,
                 model_response=model_response,
-                print_verbose=print_verbose,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
-                acompletion=acompletion,
-                logger_fn=logger_fn,
-                encoding=encoding,  # for calculating input/output tokens
+                custom_llm_provider="clarifai",
+                timeout=timeout,
+                headers=headers,
+                encoding=encoding,
                 api_key=clarifai_key,
-                logging_obj=logging,
-                custom_prompt_dict=custom_prompt_dict,
+                logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )
-
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                ## LOGGING
-                logging.post_call(
-                    input=messages,
-                    api_key=api_key,
-                    original_response=model_response,
-                )
-
-            if optional_params.get("stream", False) or acompletion is True:
-                ## LOGGING
-                logging.post_call(
-                    input=messages,
-                    api_key=clarifai_key,
-                    original_response=model_response,
-                )
-            response = model_response
         elif custom_llm_provider == "anthropic":
             api_key = (
                 api_key
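From the caller's side, the routing change above means a Clarifai request with stream=True still returns an iterable response; it is simply produced from a single underlying HTTP call and yielded as one chunk. A hedged usage note (the model id is a placeholder and CLARIFAI_API_KEY must be configured; this is not a tested snippet):

import litellm

response = litellm.completion(
    model="clarifai/mistralai.completion.mistral-7B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,  # streaming is faked for clarifai: the full completion arrives as one chunk
)

for chunk in response:
    print(chunk)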