mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
(Refactor) Code Quality improvement - Use Common base handler for clarifai/ (#7125)
* use base_llm_http_handler for clarifai
* fix clarifai completion
* handle faking streaming base llm http handler
* add fake streaming for clarifai
* add FakeStreamResponseIterator for base model iterator
* fix get_model_response_iterator
* fix base model iterator
* fix base model iterator
* add support for faking sync streams clarifai
* add fake streaming for clarifai
* remove unused code
* fix import
* fix llm http handler
* test_async_completion_clarifai
* fix clarifai tests
* fix linting
This commit is contained in:
parent c5e0407703
commit 28ff38e35d

9 changed files with 155 additions and 269 deletions
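The core idea behind these commits is "fake streaming": for a provider such as Clarifai that only returns a complete JSON completion, the handler makes an ordinary request and then replays the parsed response as if it were a one-chunk stream, so callers that asked for stream=True still get an iterator. A minimal sketch of that idea under assumed names (this is not the exact FakeStreamResponseIterator the PR adds; the class and attributes are illustrative):

from typing import Any, Dict


class SingleChunkFakeStream:
    """Illustrative: replay one fully parsed JSON response as a one-chunk 'stream'."""

    def __init__(self, parsed_response: Dict[str, Any]) -> None:
        self.parsed_response = parsed_response
        self.sent = False

    def __iter__(self) -> "SingleChunkFakeStream":
        return self

    def __next__(self) -> Dict[str, Any]:
        # Yield the complete response exactly once, then stop like a real stream.
        if self.sent:
            raise StopIteration
        self.sent = True
        return self.parsed_response

    def __aiter__(self) -> "SingleChunkFakeStream":
        return self

    async def __anext__(self) -> Dict[str, Any]:
        # Async consumers get the same single-chunk behaviour.
        if self.sent:
            raise StopAsyncIteration
        self.sent = True
        return self.parsed_response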
@@ -93,6 +93,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         acompletion: bool,
         stream: Optional[bool] = False,
+        fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
     ):
@@ -129,7 +130,8 @@ class BaseLLMHTTPHandler:

         if acompletion is True:
             if stream is True:
-                data["stream"] = stream
+                if fake_stream is not True:
+                    data["stream"] = stream
                 return self.acompletion_stream_function(
                     model=model,
                     messages=messages,
@@ -140,6 +142,7 @@ class BaseLLMHTTPHandler:
                     timeout=timeout,
                     logging_obj=logging_obj,
                     data=data,
+                    fake_stream=fake_stream,
                 )

             else:
@@ -160,7 +163,8 @@ class BaseLLMHTTPHandler:
        )

        if stream is True:
-            if fake_stream is not True:
-                data["stream"] = stream
+            if fake_stream is not True:
+                data["stream"] = stream
            completion_stream, headers = self.make_sync_call(
                provider_config=provider_config,
                api_base=api_base,
@@ -170,6 +174,7 @@ class BaseLLMHTTPHandler:
                messages=messages,
                logging_obj=logging_obj,
                timeout=timeout,
+                fake_stream=fake_stream,
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
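In both the async and sync branches above, data["stream"] = stream is now guarded by fake_stream, so the provider payload only advertises streaming when the upstream API will actually stream. A small stand-alone sketch of that rule (the helper name is illustrative, not part of the PR):

from typing import Any, Dict


def build_request_body(data: Dict[str, Any], stream: bool, fake_stream: bool) -> Dict[str, Any]:
    # Only ask the provider to stream when the stream is not being faked locally.
    if stream and not fake_stream:
        data["stream"] = stream
    return data


assert "stream" not in build_request_body({}, stream=True, fake_stream=True)
assert build_request_body({}, stream=True, fake_stream=False)["stream"] is True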
@@ -215,11 +220,15 @@ class BaseLLMHTTPHandler:
         messages: list,
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         sync_httpx_client = _get_httpx_client()
         try:
+            stream = True
+            if fake_stream is True:
+                stream = False
             response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, timeout=timeout, stream=stream
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
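make_sync_call now computes a local stream flag and hands it to the HTTP client, so a faked stream is fetched as an ordinary buffered response instead of an open streaming connection. litellm's own client wrapper accepts stream= directly, as the diff shows; a rough equivalent with plain httpx, only to illustrate the send-time switch (function name is an assumption):

import httpx


def post_maybe_streaming(
    client: httpx.Client, url: str, headers: dict, body: str, fake_stream: bool
) -> httpx.Response:
    # Buffer the body when faking; keep the connection open for incremental reads otherwise.
    request = client.build_request("POST", url, headers=headers, content=body)
    return client.send(request, stream=not fake_stream)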
@@ -240,9 +249,15 @@ class BaseLLMHTTPHandler:
                status_code=response.status_code,
                message=str(response.read()),
            )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.iter_lines(), sync_stream=True
-        )
+
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )

        # LOGGING
        logging_obj.post_call(
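The sync call now hands get_model_response_iterator either the fully parsed body (response.json()) for a fake stream or the raw line iterator (response.iter_lines()) for a real one. A hedged sketch of a dispatcher that accepts both shapes (the real provider config classes in litellm define their own iterator types; this standalone function is illustrative):

from typing import Any, Iterable, Iterator, Union


def make_response_iterator(raw: Union[dict, Iterable[str]]) -> Iterator[Any]:
    """Illustrative: accept a parsed JSON dict (fake stream) or an iterable of lines (real stream)."""
    if isinstance(raw, dict):
        # Fake stream: the complete response is replayed as a single chunk.
        yield raw
        return
    for line in raw:
        # Real stream: forward each non-empty server-sent line for chunk parsing.
        if line:
            yield line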
@@ -265,8 +280,8 @@ class BaseLLMHTTPHandler:
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
         data: dict,
+        fake_stream: bool = False,
     ):
-        data["stream"] = True
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
@@ -276,6 +291,7 @@ class BaseLLMHTTPHandler:
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            fake_stream=fake_stream,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -295,13 +311,17 @@ class BaseLLMHTTPHandler:
         messages: list,
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         async_httpx_client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders(custom_llm_provider)
         )
+        stream = True
+        if fake_stream is True:
+            stream = False
         try:
             response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, stream=stream, timeout=timeout
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -322,10 +342,14 @@ class BaseLLMHTTPHandler:
                status_code=response.status_code,
                message=str(response.read()),
            )
-
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.aiter_lines(), sync_stream=False
-        )
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=False
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.aiter_lines(), sync_stream=False
+            )
        # LOGGING
        logging_obj.post_call(
            input=messages,
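The async path mirrors the sync one: when fake_stream is set, the buffered response.json() is wrapped so that downstream code can keep using async for without knowing the stream was faked. A minimal async sketch of that wrapping and its consumption (illustrative only; it is not litellm's CustomStreamWrapper, and the response payload is made up):

import asyncio
from typing import Any, AsyncIterator, Dict


async def fake_async_stream(parsed_response: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]:
    # Yield the already-complete response once so callers can still use `async for`.
    yield parsed_response


async def consume() -> None:
    async for chunk in fake_async_stream({"choices": [{"message": {"content": "hi"}}]}):
        print(chunk)


asyncio.run(consume())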