Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
(Refactor) Code Quality improvement - Use Common base handler for clarifai/ (#7125)
* use base_llm_http_handler for clarifai
* fix clarifai completion
* handle fake streaming in base llm http handler
* add fake streaming for clarifai
* add FakeStreamResponseIterator for base model iterator
* fix get_model_response_iterator
* fix base model iterator
* fix base model iterator
* add support for faking sync streams for clarifai
* add fake streaming for clarifai
* remove unused code
* fix import
* fix llm http handler
* test_async_completion_clarifai
* fix clarifai tests
* fix linting
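
The central idea in these commits is "fake streaming": Clarifai only returns complete responses, so when a caller asks for a stream the handler makes one non-streaming request and wraps the parsed result in an iterator that yields a single chunk. A minimal illustrative sketch of that pattern (a simplification, not the literal LiteLLM code):

# Minimal sketch of the fake-streaming idea (an assumption-level simplification,
# not LiteLLM's actual classes): one complete response is exposed through an
# iterator that yields exactly one chunk and then stops.
class SingleChunkStream:
    def __init__(self, full_response: dict):
        self.full_response = full_response
        self.is_done = False

    def __iter__(self):
        return self

    def __next__(self) -> dict:
        if self.is_done:
            raise StopIteration
        self.is_done = True
        # A real provider iterator would run this through a chunk parser first.
        return self.full_response


# Callers can loop over it exactly as they would over a real streaming response.
for chunk in SingleChunkStream({"text": "hello", "is_finished": True}):
    print(chunk)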
This commit is contained in: parent c5e0407703, commit 28ff38e35d
9 changed files with 155 additions and 269 deletions
@@ -803,49 +803,6 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e
 
-    def handle_clarifai_completion_chunk(self, chunk):
-        try:
-            if isinstance(chunk, dict):
-                parsed_response = chunk
-            elif isinstance(chunk, (str, bytes)):
-                if isinstance(chunk, bytes):
-                    parsed_response = chunk.decode("utf-8")
-                else:
-                    parsed_response = chunk
-            else:
-                raise ValueError("Unable to parse streaming chunk")
-            if isinstance(parsed_response, dict):
-                data_json = parsed_response
-            else:
-                data_json = json.loads(parsed_response)
-            text = (
-                data_json.get("outputs", "")[0]
-                .get("data", "")
-                .get("text", "")
-                .get("raw", "")
-            )
-            len(
-                encoding.encode(
-                    data_json.get("outputs", "")[0]
-                    .get("input", "")
-                    .get("data", "")
-                    .get("text", "")
-                    .get("raw", "")
-                )
-            )
-            len(encoding.encode(text))
-            return {
-                "text": text,
-                "is_finished": True,
-            }
-        except Exception as e:
-            verbose_logger.exception(
-                "litellm.CustomStreamWrapper.handle_clarifai_chunk(): Exception occured - {}".format(
-                    str(e)
-                )
-            )
-            return ""
-
     def model_response_creator(
         self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
     ):
@@ -1112,11 +1069,6 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
-                response_obj = self.handle_clarifai_completion_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                 response_obj = self.handle_replicate_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
litellm/llms/base_llm/base_model_iterator.py (new file, 43 lines added)

@@ -0,0 +1,43 @@
+import json
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+    ModelResponse,
+)
+
+
+class FakeStreamResponseIterator:
+    def __init__(self, model_response, json_mode: Optional[bool] = False):
+        self.model_response = model_response
+        self.json_mode = json_mode
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    @abstractmethod
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        pass
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
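For orientation, a provider is expected to subclass this base iterator and implement chunk_parser. The snippet below is a hypothetical example (the class name ExampleProviderIterator and the chunk shape are illustrative assumptions); the real Clarifai subclass appears later in this diff.

from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
from litellm.types.utils import GenericStreamingChunk


class ExampleProviderIterator(FakeStreamResponseIterator):
    # Hypothetical subclass: turns one already-complete provider response
    # into a single GenericStreamingChunk.
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        return GenericStreamingChunk(
            text=chunk.get("text", ""),
            tool_use=None,
            is_finished=True,
            finish_reason="stop",
            usage=None,
            index=0,
            provider_specific_fields=None,
        )

A provider config would then return an instance of such a subclass from its get_model_response_iterator method.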
@@ -136,7 +136,7 @@ class BaseConfig(ABC):
 
     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
    ) -> Any:
Deleted file (all 177 lines removed):

@@ -1,177 +0,0 @@
-import json
-import os
-import time
-import traceback
-import types
-from typing import Callable, List, Optional
-
-import httpx
-import requests
-
-import litellm
-from litellm.llms.custom_httpx.http_handler import (
-    AsyncHTTPHandler,
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.types.llms.openai import AllMessageValues
-from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
-
-from ...prompt_templates.factory import custom_prompt, prompt_factory
-from ..common_utils import ClarifaiError
-
-
-async def async_completion(
-    model: str,
-    messages: List[AllMessageValues],
-    model_response: ModelResponse,
-    encoding,
-    api_key,
-    api_base: str,
-    logging_obj,
-    data: dict,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-    headers={},
-):
-
-    async_handler = get_async_httpx_client(
-        llm_provider=litellm.LlmProviders.CLARIFAI,
-        params={"timeout": 600.0},
-    )
-    response = await async_handler.post(
-        url=api_base, headers=headers, data=json.dumps(data)
-    )
-
-    return litellm.ClarifaiConfig().transform_response(
-        model=model,
-        raw_response=response,
-        model_response=model_response,
-        logging_obj=logging_obj,
-        api_key=api_key,
-        request_data=data,
-        messages=messages,
-        optional_params=optional_params,
-        encoding=encoding,
-    )
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params: dict,
-    litellm_params: dict,
-    custom_prompt_dict={},
-    acompletion=False,
-    logger_fn=None,
-    headers={},
-):
-    headers = litellm.ClarifaiConfig().validate_environment(
-        api_key=api_key,
-        headers=headers,
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-    )
-    data = litellm.ClarifaiConfig().transform_request(
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-        litellm_params=litellm_params,
-        headers=headers,
-    )
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=data,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": model,
-        },
-    )
-    if acompletion is True:
-        return async_completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            model_response=model_response,
-            encoding=encoding,
-            api_key=api_key,
-            logging_obj=logging_obj,
-            data=data,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=headers,
-        )
-    else:
-        ## COMPLETION CALL
-        httpx_client = _get_httpx_client(
-            params={"timeout": 600.0},
-        )
-        response = httpx_client.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-        )
-
-        if response.status_code != 200:
-            raise ClarifaiError(status_code=response.status_code, message=response.text)
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            completion_stream = response.iter_lines()
-            stream_response = CustomStreamWrapper(
-                completion_stream=completion_stream,
-                model=model,
-                custom_llm_provider="clarifai",
-                logging_obj=logging_obj,
-            )
-            return stream_response
-
-        else:
-            return litellm.ClarifaiConfig().transform_response(
-                model=model,
-                raw_response=response,
-                model_response=model_response,
-                logging_obj=logging_obj,
-                api_key=api_key,
-                request_data=data,
-                messages=messages,
-                optional_params=optional_params,
-                encoding=encoding,
-            )
-
-
-class ModelResponseIterator:
-    def __init__(self, model_response):
-        self.model_response = model_response
-        self.is_done = False
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.is_done:
-            raise StopIteration
-        self.is_done = True
-        return self.model_response
-
-    # Async iterator
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        if self.is_done:
-            raise StopAsyncIteration
-        self.is_done = True
-        return self.model_response
@@ -1,13 +1,23 @@
 import json
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 
 import httpx
 
 import litellm
+from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+    Usage,
+)
+from litellm.utils import token_counter
 
 from ..common_utils import ClarifaiError
@@ -199,3 +209,56 @@ class ClarifaiConfig(BaseConfig):
             ),
         )
         return model_response
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return ClarifaiModelResponseIterator(
+            model_response=streaming_response,
+            json_mode=json_mode,
+        )
+
+
+class ClarifaiModelResponseIterator(FakeStreamResponseIterator):
+    def __init__(
+        self,
+        model_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        json_mode: Optional[bool] = False,
+    ):
+        super().__init__(
+            model_response=model_response,
+            json_mode=json_mode,
+        )
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            text = (
+                chunk.get("outputs", "")[0]
+                .get("data", "")
+                .get("text", "")
+                .get("raw", "")
+            )
+
+            index: int = 0
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
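To make the chunk_parser above concrete, here is a hedged example of the response shape it expects. The payload values are invented, but the nested outputs[0].data.text.raw path mirrors the code above, and the iterator class is assumed to be in scope as defined in this diff.

# Invented Clarifai-style response body, for illustration only.
example_response = {
    "outputs": [
        {"data": {"text": {"raw": "Hello from Clarifai!"}}}
    ]
}

iterator = ClarifaiModelResponseIterator(model_response=example_response)
chunk = next(iter(iterator))
# chunk carries text="Hello from Clarifai!", is_finished=False, finish_reason=""
# (attribute vs. key access depends on how GenericStreamingChunk is defined).
print(chunk)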
@@ -368,7 +368,7 @@ class CohereChatConfig(BaseConfig):
 
     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ):
@@ -93,6 +93,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         acompletion: bool,
         stream: Optional[bool] = False,
+        fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
     ):
@@ -129,7 +130,8 @@ class BaseLLMHTTPHandler:
 
         if acompletion is True:
             if stream is True:
-                data["stream"] = stream
+                if fake_stream is not True:
+                    data["stream"] = stream
                 return self.acompletion_stream_function(
                     model=model,
                     messages=messages,
@@ -140,6 +142,7 @@ class BaseLLMHTTPHandler:
                     timeout=timeout,
                     logging_obj=logging_obj,
                     data=data,
+                    fake_stream=fake_stream,
                 )
 
             else:
@@ -160,7 +163,8 @@ class BaseLLMHTTPHandler:
             )
 
         if stream is True:
-            data["stream"] = stream
+            if fake_stream is not True:
+                data["stream"] = stream
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
                 api_base=api_base,
@@ -170,6 +174,7 @@ class BaseLLMHTTPHandler:
                 messages=messages,
                 logging_obj=logging_obj,
                 timeout=timeout,
+                fake_stream=fake_stream,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -215,11 +220,15 @@ class BaseLLMHTTPHandler:
         messages: list,
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         sync_httpx_client = _get_httpx_client()
         try:
+            stream = True
+            if fake_stream is True:
+                stream = False
             response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, timeout=timeout, stream=stream
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -240,9 +249,15 @@ class BaseLLMHTTPHandler:
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.iter_lines(), sync_stream=True
-        )
+
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )
 
         # LOGGING
         logging_obj.post_call(
@@ -265,8 +280,8 @@ class BaseLLMHTTPHandler:
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
         data: dict,
+        fake_stream: bool = False,
     ):
-        data["stream"] = True
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
@@ -276,6 +291,7 @@ class BaseLLMHTTPHandler:
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            fake_stream=fake_stream,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -295,13 +311,17 @@ class BaseLLMHTTPHandler:
         messages: list,
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         async_httpx_client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders(custom_llm_provider)
         )
+        stream = True
+        if fake_stream is True:
+            stream = False
         try:
             response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, stream=stream, timeout=timeout
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -322,10 +342,14 @@ class BaseLLMHTTPHandler:
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.aiter_lines(), sync_stream=False
-        )
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=False
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.aiter_lines(), sync_stream=False
+            )
         # LOGGING
         logging_obj.post_call(
             input=messages,
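Taken together, the handler changes above follow one pattern: when fake_stream is set, the request is sent without streaming and the parsed JSON body, rather than a line iterator, is handed to the provider's response iterator. A condensed sketch of that control flow (simplified from the diff; error handling and logging omitted, and the client and config are passed in rather than constructed):

# Condensed sketch of the fake_stream branch; not the full handler logic.
def make_sync_call_sketch(
    sync_httpx_client, provider_config, api_base, headers, data, timeout,
    fake_stream: bool = False,
):
    # Only open a real HTTP stream when we are not faking it.
    stream = not fake_stream
    response = sync_httpx_client.post(
        api_base, headers=headers, data=data, timeout=timeout, stream=stream
    )
    if fake_stream:
        # Whole JSON body at once -> provider wraps it as a single-chunk stream.
        return provider_config.get_model_response_iterator(
            streaming_response=response.json(), sync_stream=True
        )
    # Real streaming -> feed the line iterator to the provider's parser.
    return provider_config.get_model_response_iterator(
        streaming_response=response.iter_lines(), sync_stream=True
    )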
@@ -110,7 +110,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.clarifai.chat import handler
 from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
@@ -1689,41 +1688,23 @@ def completion(  # type: ignore  # noqa: PLR0915
                 or "https://api.clarifai.com/v2"
             )
             api_base = litellm.ClarifaiConfig()._convert_model_to_url(model, api_base)
-            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = handler.completion(
+            response = base_llm_http_handler.completion(
                 model=model,
+                stream=stream,
+                fake_stream=True,  # clarifai does not support streaming, we fake it
                 messages=messages,
+                acompletion=acompletion,
                 api_base=api_base,
                 model_response=model_response,
-                print_verbose=print_verbose,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
-                acompletion=acompletion,
-                logger_fn=logger_fn,
-                encoding=encoding,  # for calculating input/output tokens
+                custom_llm_provider="clarifai",
+                timeout=timeout,
+                headers=headers,
+                encoding=encoding,
                 api_key=clarifai_key,
-                logging_obj=logging,
-                custom_prompt_dict=custom_prompt_dict,
+                logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )
-
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                ## LOGGING
-                logging.post_call(
-                    input=messages,
-                    api_key=api_key,
-                    original_response=model_response,
-                )
-
-            if optional_params.get("stream", False) or acompletion is True:
-                ## LOGGING
-                logging.post_call(
-                    input=messages,
-                    api_key=clarifai_key,
-                    original_response=model_response,
-                )
-            response = model_response
-
         elif custom_llm_provider == "anthropic":
             api_key = (
                 api_key
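With the routing change above, callers do not need to know that Clarifai cannot stream natively: requesting a stream still yields chunk objects, produced through the fake-stream path. A hedged usage sketch follows; the model id and API key are placeholders, so check the Clarifai provider docs for the exact model-id format.

import litellm

# Placeholder model id and key, for illustration only.
response = litellm.completion(
    model="clarifai/<user>.<app>.<model>",
    messages=[{"role": "user", "content": "Say hello"}],
    api_key="CLARIFAI_PAT_PLACEHOLDER",
    stream=True,  # streaming is faked: the full response arrives as one chunk
)

for chunk in response:
    print(chunk)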