(Refactor) Code Quality improvement - Use Common base handler for clarifai/ (#7125)

* use base_llm_http_handler for clarifai

* fix clarifai completion

* handle faking streaming base llm http handler

* add fake streaming for clarifai

* add FakeStreamResponseIterator for base model iterator

* fix get_model_response_iterator

* fix base model iterator

* fix base model iterator

* add support for faking sync streams for clarifai

* add fake streaming for clarifai

* remove unused code

* fix import

* fix llm http handler

* test_async_completion_clarifai

* fix clarifai tests

* fix linting
Ishaan Jaff 2024-12-09 21:04:48 -08:00 committed by GitHub
parent c5e0407703
commit 28ff38e35d
9 changed files with 155 additions and 269 deletions
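
The diff below threads a fake_stream flag through BaseLLMHTTPHandler: when a provider (such as Clarifai) cannot stream natively, the handler leaves "stream" out of the request payload, makes a normal HTTP call, and hands the parsed JSON body to the provider's response iterator so callers can still consume it as a stream. A minimal sketch of that idea follows; the class name and shape are illustrative, not the actual FakeStreamResponseIterator added in this commit.

# Hypothetical sketch: replay a complete (non-streamed) JSON response as a
# one-chunk stream. Illustrative only; the real litellm iterator differs.
from typing import Any, Dict, Iterator


class SingleChunkFakeStream:
    def __init__(self, full_response: Dict[str, Any]):
        self.full_response = full_response
        self._done = False

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        # Yield the whole parsed response once as a single "chunk", then stop.
        if self._done:
            raise StopIteration
        self._done = True
        return self.full_response

    def __aiter__(self):
        return self

    async def __anext__(self) -> Dict[str, Any]:
        # Async variant so the same object also works in the async code path.
        if self._done:
            raise StopAsyncIteration
        self._done = True
        return self.full_response

CustomStreamWrapper can then iterate this the same way it iterates a real SSE stream; the handler only has to decide which kind of iterator to build, which is what the fake_stream branches in the diff do.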


@@ -93,6 +93,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         acompletion: bool,
         stream: Optional[bool] = False,
+        fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
     ):
@@ -129,7 +130,8 @@
         if acompletion is True:
             if stream is True:
-                data["stream"] = stream
+                if fake_stream is not True:
+                    data["stream"] = stream
                 return self.acompletion_stream_function(
                     model=model,
                     messages=messages,
@@ -140,6 +142,7 @@
                     timeout=timeout,
                     logging_obj=logging_obj,
                     data=data,
+                    fake_stream=fake_stream,
                 )
             else:
@@ -160,7 +163,8 @@
         )
         if stream is True:
-            data["stream"] = stream
+            if fake_stream is not True:
+                data["stream"] = stream
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
                 api_base=api_base,
@@ -170,6 +174,7 @@
                 messages=messages,
                 logging_obj=logging_obj,
                 timeout=timeout,
+                fake_stream=fake_stream,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -215,11 +220,15 @@
         messages: list,
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         sync_httpx_client = _get_httpx_client()
         try:
+            stream = True
+            if fake_stream is True:
+                stream = False
             response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, timeout=timeout, stream=stream
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -240,9 +249,15 @@
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.iter_lines(), sync_stream=True
-        )
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )
         # LOGGING
         logging_obj.post_call(
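
In the sync path above, get_model_response_iterator now receives either a parsed dict (when fake_stream is set) or response.iter_lines(). A hedged sketch of how a provider config could accept both inputs; the class and branching below are illustrative, not the actual clarifai config.

# Illustrative only: dispatch on the type of streaming_response.
from typing import Any, Iterable, Union


class ExampleProviderConfig:
    def get_model_response_iterator(
        self, streaming_response: Union[dict, Iterable[Any]], sync_stream: bool
    ):
        if isinstance(streaming_response, dict):
            # Fake stream: the full JSON body, replayed as a single chunk.
            return iter([streaming_response])
        # Real stream: iterate raw lines / SSE events as they arrive.
        return iter(streaming_response)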
@@ -265,8 +280,8 @@
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
         data: dict,
+        fake_stream: bool = False,
     ):
-        data["stream"] = True
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
@@ -276,6 +291,7 @@
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            fake_stream=fake_stream,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -295,13 +311,17 @@
         messages: list,
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
+        fake_stream: bool = False,
     ) -> Tuple[Any, httpx.Headers]:
         async_httpx_client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders(custom_llm_provider)
         )
+        stream = True
+        if fake_stream is True:
+            stream = False
         try:
             response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=True, timeout=timeout
+                api_base, headers=headers, data=data, stream=stream, timeout=timeout
             )
         except httpx.HTTPStatusError as e:
             raise self._handle_error(
@@ -322,10 +342,14 @@
                 status_code=response.status_code,
                 message=str(response.read()),
             )
-        completion_stream = provider_config.get_model_response_iterator(
-            streaming_response=response.aiter_lines(), sync_stream=False
-        )
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=False
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.aiter_lines(), sync_stream=False
+            )
         # LOGGING
         logging_obj.post_call(
             input=messages,
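
With fake_stream wired through the handler, callers can keep requesting streaming responses from providers that only return complete payloads. A usage sketch, assuming litellm's public completion API; the Clarifai model identifier below is illustrative.

import litellm

# stream=True still yields chunk objects; under the hood the handler sends a
# normal request and replays the parsed response as a single fake chunk.
response = litellm.completion(
    model="clarifai/mistralai.completion.mistral-7B-Instruct",  # illustrative id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    print(chunk)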