Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
(Refactor) Code Quality improvement - Use Common base handler for clarifai/ (#7125)
* use base_llm_http_handler for clarifai
* fix clarifai completion
* handle faking streaming in the base LLM HTTP handler
* add fake streaming for clarifai
* add FakeStreamResponseIterator for base model iterator
* fix get_model_response_iterator
* fix base model iterator
* add support for faking sync streams for clarifai
* remove unused code
* fix import
* fix llm http handler
* test_async_completion_clarifai
* fix clarifai tests
* fix linting
Parent: c5e0407703
Commit: 28ff38e35d
9 changed files with 155 additions and 269 deletions
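
The core pattern in this commit: Clarifai only returns complete (non-streaming) responses, so instead of a provider-specific chunk handler inside CustomStreamWrapper, the base HTTP handler posts a normal request and wraps the single response in an iterator that yields it as one streaming chunk. A minimal sketch of that pattern, reusing the FakeStreamResponseIterator added in this commit (the OneShotIterator subclass and the example payload are illustrative, not part of the diff):

from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
from litellm.types.utils import GenericStreamingChunk


class OneShotIterator(FakeStreamResponseIterator):
    # Illustrative subclass: turn one complete provider response into a
    # single GenericStreamingChunk.
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        return GenericStreamingChunk(
            text=chunk.get("text", ""),
            tool_use=None,
            is_finished=False,
            finish_reason="",
            usage=None,
            index=0,
            provider_specific_fields=None,
        )


# The iterator yields exactly one chunk and then stops, which is enough
# for CustomStreamWrapper to consume it like a real stream.
for parsed_chunk in OneShotIterator(model_response={"text": "hello"}):
    print(parsed_chunk["text"])
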
@@ -803,49 +803,6 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e

-    def handle_clarifai_completion_chunk(self, chunk):
-        try:
-            if isinstance(chunk, dict):
-                parsed_response = chunk
-            elif isinstance(chunk, (str, bytes)):
-                if isinstance(chunk, bytes):
-                    parsed_response = chunk.decode("utf-8")
-                else:
-                    parsed_response = chunk
-            else:
-                raise ValueError("Unable to parse streaming chunk")
-            if isinstance(parsed_response, dict):
-                data_json = parsed_response
-            else:
-                data_json = json.loads(parsed_response)
-            text = (
-                data_json.get("outputs", "")[0]
-                .get("data", "")
-                .get("text", "")
-                .get("raw", "")
-            )
-            len(
-                encoding.encode(
-                    data_json.get("outputs", "")[0]
-                    .get("input", "")
-                    .get("data", "")
-                    .get("text", "")
-                    .get("raw", "")
-                )
-            )
-            len(encoding.encode(text))
-            return {
-                "text": text,
-                "is_finished": True,
-            }
-        except Exception as e:
-            verbose_logger.exception(
-                "litellm.CustomStreamWrapper.handle_clarifai_chunk(): Exception occured - {}".format(
-                    str(e)
-                )
-            )
-            return ""
-
     def model_response_creator(
         self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
     ):
@@ -1112,11 +1069,6 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
-                response_obj = self.handle_clarifai_completion_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                 response_obj = self.handle_replicate_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
litellm/llms/base_llm/base_model_iterator.py (new file, +43 lines)
@@ -0,0 +1,43 @@
+import json
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+    ModelResponse,
+)
+
+
+class FakeStreamResponseIterator:
+    def __init__(self, model_response, json_mode: Optional[bool] = False):
+        self.model_response = model_response
+        self.json_mode = json_mode
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    @abstractmethod
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        pass
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.chunk_parser(self.model_response)
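
The new iterator also covers the async path: __aiter__/__anext__ yield the parsed chunk once and then raise StopAsyncIteration. A short usage sketch, continuing the illustrative OneShotIterator subclass from the sketch above:

import asyncio


async def consume_fake_stream() -> None:
    # One chunk is produced, then the iterator signals completion.
    async for parsed_chunk in OneShotIterator(model_response={"text": "hello"}):
        print(parsed_chunk["text"])


asyncio.run(consume_fake_stream())
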
@@ -136,7 +136,7 @@ class BaseConfig(ABC):

     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ) -> Any:
@@ -1,177 +0,0 @@
-import json
-import os
-import time
-import traceback
-import types
-from typing import Callable, List, Optional
-
-import httpx
-import requests
-
-import litellm
-from litellm.llms.custom_httpx.http_handler import (
-    AsyncHTTPHandler,
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.types.llms.openai import AllMessageValues
-from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
-
-from ...prompt_templates.factory import custom_prompt, prompt_factory
-from ..common_utils import ClarifaiError
-
-
-async def async_completion(
-    model: str,
-    messages: List[AllMessageValues],
-    model_response: ModelResponse,
-    encoding,
-    api_key,
-    api_base: str,
-    logging_obj,
-    data: dict,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-    headers={},
-):
-
-    async_handler = get_async_httpx_client(
-        llm_provider=litellm.LlmProviders.CLARIFAI,
-        params={"timeout": 600.0},
-    )
-    response = await async_handler.post(
-        url=api_base, headers=headers, data=json.dumps(data)
-    )
-
-    return litellm.ClarifaiConfig().transform_response(
-        model=model,
-        raw_response=response,
-        model_response=model_response,
-        logging_obj=logging_obj,
-        api_key=api_key,
-        request_data=data,
-        messages=messages,
-        optional_params=optional_params,
-        encoding=encoding,
-    )
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params: dict,
-    litellm_params: dict,
-    custom_prompt_dict={},
-    acompletion=False,
-    logger_fn=None,
-    headers={},
-):
-    headers = litellm.ClarifaiConfig().validate_environment(
-        api_key=api_key,
-        headers=headers,
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-    )
-    data = litellm.ClarifaiConfig().transform_request(
-        model=model,
-        messages=messages,
-        optional_params=optional_params,
-        litellm_params=litellm_params,
-        headers=headers,
-    )
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=data,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": model,
-        },
-    )
-    if acompletion is True:
-        return async_completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            model_response=model_response,
-            encoding=encoding,
-            api_key=api_key,
-            logging_obj=logging_obj,
-            data=data,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=headers,
-        )
-    else:
-        ## COMPLETION CALL
-        httpx_client = _get_httpx_client(
-            params={"timeout": 600.0},
-        )
-        response = httpx_client.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-        )
-
-        if response.status_code != 200:
-            raise ClarifaiError(status_code=response.status_code, message=response.text)
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            completion_stream = response.iter_lines()
-            stream_response = CustomStreamWrapper(
-                completion_stream=completion_stream,
-                model=model,
-                custom_llm_provider="clarifai",
-                logging_obj=logging_obj,
-            )
-            return stream_response
-
-        else:
-            return litellm.ClarifaiConfig().transform_response(
-                model=model,
-                raw_response=response,
-                model_response=model_response,
-                logging_obj=logging_obj,
-                api_key=api_key,
-                request_data=data,
-                messages=messages,
-                optional_params=optional_params,
-                encoding=encoding,
-            )
-
-
-class ModelResponseIterator:
-    def __init__(self, model_response):
-        self.model_response = model_response
-        self.is_done = False
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.is_done:
-            raise StopIteration
-        self.is_done = True
-        return self.model_response
-
-    # Async iterator
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        if self.is_done:
-            raise StopAsyncIteration
-        self.is_done = True
-        return self.model_response
@@ -1,13 +1,23 @@
+import json
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

 import httpx

 import litellm
+from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+    Usage,
+)
 from litellm.utils import token_counter

 from ..common_utils import ClarifaiError
@@ -199,3 +209,56 @@ class ClarifaiConfig(BaseConfig):
             ),
         )
         return model_response
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return ClarifaiModelResponseIterator(
+            model_response=streaming_response,
+            json_mode=json_mode,
+        )
+
+
+class ClarifaiModelResponseIterator(FakeStreamResponseIterator):
+    def __init__(
+        self,
+        model_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        json_mode: Optional[bool] = False,
+    ):
+        super().__init__(
+            model_response=model_response,
+            json_mode=json_mode,
+        )
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            text = (
+                chunk.get("outputs", "")[0]
+                .get("data", "")
+                .get("text", "")
+                .get("raw", "")
+            )
+
+            index: int = 0
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
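
For reference, chunk_parser above walks a response dict of the shape Clarifai's predict endpoint returns; the sample below is inferred from the accessor chain in the diff (outputs[0].data.text.raw) and is illustrative only:

# Illustrative Clarifai-style response, inferred from the accessor chain
# chunk.get("outputs", "")[0].get("data", "").get("text", "").get("raw", "")
sample_response = {
    "outputs": [
        {"data": {"text": {"raw": "Hello from Clarifai"}}}
    ]
}

iterator = ClarifaiModelResponseIterator(model_response=sample_response)
parsed = iterator.chunk_parser(sample_response)
print(parsed["text"])  # "Hello from Clarifai"
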
@@ -368,7 +368,7 @@ class CohereChatConfig(BaseConfig):

     def get_model_response_iterator(
         self,
-        streaming_response: Union[Iterator[str], AsyncIterator[str]],
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ):
@@ -93,6 +93,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         acompletion: bool,
         stream: Optional[bool] = False,
+        fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
     ):
@@ -129,7 +130,8 @@ class BaseLLMHTTPHandler:

         if acompletion is True:
             if stream is True:
-                data["stream"] = stream
+                if fake_stream is not True:
+                    data["stream"] = stream
                 return self.acompletion_stream_function(
                     model=model,
                     messages=messages,
@@ -140,6 +142,7 @@ class BaseLLMHTTPHandler:
                     timeout=timeout,
                     logging_obj=logging_obj,
                     data=data,
+                    fake_stream=fake_stream,
                 )

             else:
@@ -160,7 +163,8 @@ class BaseLLMHTTPHandler:
        )

        if stream is True:
-            data["stream"] = stream
+            if fake_stream is not True:
+                data["stream"] = stream
            completion_stream, headers = self.make_sync_call(
                provider_config=provider_config,
                api_base=api_base,
@@ -170,6 +174,7 @@ class BaseLLMHTTPHandler:
                messages=messages,
                logging_obj=logging_obj,
                timeout=timeout,
+                fake_stream=fake_stream,
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
@@ -215,11 +220,15 @@ class BaseLLMHTTPHandler:
        messages: list,
        logging_obj,
        timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
    ) -> Tuple[Any, httpx.Headers]:
        sync_httpx_client = _get_httpx_client()
        try:
+            stream = True
+            if fake_stream is True:
+                stream = False
            response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, timeout=timeout, stream=stream
            )
        except httpx.HTTPStatusError as e:
            raise self._handle_error(
@@ -240,9 +249,15 @@ class BaseLLMHTTPHandler:
                status_code=response.status_code,
                message=str(response.read()),
            )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.iter_lines(), sync_stream=True
-        )
+
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )

        # LOGGING
        logging_obj.post_call(
@@ -265,8 +280,8 @@ class BaseLLMHTTPHandler:
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
        data: dict,
+        fake_stream: bool = False,
    ):
-        data["stream"] = True
        completion_stream, _response_headers = await self.make_async_call(
            custom_llm_provider=custom_llm_provider,
            provider_config=provider_config,
@@ -276,6 +291,7 @@ class BaseLLMHTTPHandler:
            messages=messages,
            logging_obj=logging_obj,
            timeout=timeout,
+            fake_stream=fake_stream,
        )
        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
@@ -295,13 +311,17 @@ class BaseLLMHTTPHandler:
        messages: list,
        logging_obj: LiteLLMLoggingObj,
        timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
    ) -> Tuple[Any, httpx.Headers]:
        async_httpx_client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders(custom_llm_provider)
        )
+        stream = True
+        if fake_stream is True:
+            stream = False
        try:
            response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, stream=stream, timeout=timeout
            )
        except httpx.HTTPStatusError as e:
            raise self._handle_error(
@@ -322,10 +342,14 @@ class BaseLLMHTTPHandler:
                status_code=response.status_code,
                message=str(response.read()),
            )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.aiter_lines(), sync_stream=False
-        )
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=False
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.aiter_lines(), sync_stream=False
+            )
        # LOGGING
        logging_obj.post_call(
            input=messages,
@@ -110,7 +110,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.clarifai.chat import handler
 from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
@@ -1689,41 +1688,23 @@ def completion(  # type: ignore  # noqa: PLR0915
            or "https://api.clarifai.com/v2"
        )
        api_base = litellm.ClarifaiConfig()._convert_model_to_url(model, api_base)
-        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        model_response = handler.completion(
+        response = base_llm_http_handler.completion(
            model=model,
+            stream=stream,
+            fake_stream=True,  # clarifai does not support streaming, we fake it
            messages=messages,
+            acompletion=acompletion,
            api_base=api_base,
            model_response=model_response,
-            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
-            acompletion=acompletion,
-            logger_fn=logger_fn,
-            encoding=encoding,  # for calculating input/output tokens
+            custom_llm_provider="clarifai",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
            api_key=clarifai_key,
-            logging_obj=logging,
-            custom_prompt_dict=custom_prompt_dict,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
        )
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            # don't try to access stream object,
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=model_response,
-            )
-
-        if optional_params.get("stream", False) or acompletion is True:
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=clarifai_key,
-                original_response=model_response,
-            )
-        response = model_response
-
    elif custom_llm_provider == "anthropic":
        api_key = (
            api_key
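
From the caller's side nothing changes: asking for a stream on a Clarifai model still yields chunk objects, they are just synthesized from one non-streaming HTTP response (fake_stream=True above). A hedged usage sketch; the model id is a placeholder and the example assumes CLARIFAI_API_KEY is set in the environment:

import litellm

# stream=True is honored even though Clarifai itself does not stream;
# the handler makes a single request and fakes the stream from its JSON body.
response = litellm.completion(
    model="clarifai/openai.chat-completion.GPT-4",  # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")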