(Refactor) Code Quality improvement - Use Common base handler for cloudflare/ provider (#7127)

* add get_complete_url to base config

* cloudflare - refactor to follow the existing pattern

* migrate cloudflare chat completions to base llm http handler

* fix unused import

* fix fake stream in cloudflare

* fix cloudflare transformation

* fix naming for BaseModelResponseIterator

* add async cloudflare streaming test

* test cloudflare

* add handler.py

* add handler.py for cohere
Ishaan Jaff 2024-12-10 10:12:22 -08:00 committed by GitHub
parent 1ef311343c
commit bd39e1ab5d
14 changed files with 391 additions and 268 deletions
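The `get_complete_url` bullet above refers to a URL-construction hook on the shared provider config, which the common HTTP handler can call instead of each provider assembling its endpoint ad hoc. Below is a minimal sketch of the idea; the class name, method signature, and URL layout are illustrative, not the exact litellm interfaces:

```python
from typing import Optional


class CloudflareChatConfigSketch:
    """Illustrative provider config exposing a get_complete_url hook.

    The shared HTTP handler asks the config for the full request URL,
    so URL assembly no longer lives in a provider-specific handler.
    """

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        # Default to the public Workers AI endpoint when no api_base is supplied;
        # the account id placeholder would normally be filled from the environment.
        base = api_base or "https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run"
        return f"{base.rstrip('/')}/{model}"
```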
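The "fake stream" fix above concerns providers whose API only returns a complete response: to honour `stream=True`, the handler slices that response into OpenAI-style delta chunks and yields them as a stream. A rough, self-contained sketch of the chunking idea (not litellm's actual iterator implementation):

```python
from typing import Iterator


def fake_stream_chunks(full_text: str, model: str, chunk_size: int = 20) -> Iterator[dict]:
    """Yield OpenAI-style delta chunks from an already-complete completion."""
    for start in range(0, len(full_text), chunk_size):
        piece = full_text[start : start + chunk_size]
        yield {
            "model": model,
            "choices": [{"delta": {"content": piece}, "finish_reason": None}],
        }
    # Terminal chunk signalling the end of the stream.
    yield {"model": model, "choices": [{"delta": {}, "finish_reason": "stop"}]}
```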

@@ -86,7 +86,6 @@ from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
 from .llms import (
     aleph_alpha,
     baseten,
-    cloudflare,
     maritalk,
     nlp_cloud,
     ollama,
@@ -471,6 +470,7 @@ async def acompletion(
         or custom_llm_provider == "triton"
         or custom_llm_provider == "clarifai"
         or custom_llm_provider == "watsonx"
+        or custom_llm_provider == "cloudflare"
         or custom_llm_provider in litellm.openai_compatible_providers
         or custom_llm_provider in litellm._custom_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@@ -2828,37 +2828,22 @@ def completion( # type: ignore # noqa: PLR0915
         )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = cloudflare.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
+            acompletion=acompletion,
             api_base=api_base,
-            custom_prompt_dict=litellm.custom_prompt_dict,
             model_response=model_response,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,  # for calculating input/output tokens
+            custom_llm_provider="cloudflare",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
             api_key=api_key,
-            logging_obj=logging,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )
-        if "stream" in optional_params and optional_params["stream"] is True:
-            # don't try to access stream object,
-            response = CustomStreamWrapper(
-                response,
-                model,
-                custom_llm_provider="cloudflare",
-                logging_obj=logging,
-            )
-        if optional_params.get("stream", False) or acompletion is True:
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-            )
-        response = response
     elif (
         custom_llm_provider == "baseten"
         or litellm.api_base == "https://app.baseten.co"
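After this hunk, a cloudflare call is routed through base_llm_http_handler internally (including stream handling), while the public interface stays the same. A quick illustrative call; the model name and prompt are arbitrary:

```python
import litellm

# Assumes CLOUDFLARE_API_KEY and CLOUDFLARE_ACCOUNT_ID are set in the environment.
response = litellm.completion(
    model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)
```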