(Refactor) Code Quality improvement - Use Common base handler for cloudflare/ provider (#7127)

* add get_complete_url to base config

* cloudflare - refactor to follow the existing pattern

* migrate cloudflare chat completions to base llm http handler

* fix unused import

* fix fake stream in cloudflare

* fix cloudflare transformation

* fix naming for BaseModelResponseIterator

* add async cloudflare streaming test

* test cloudflare

* add handler.py

* add handler.py for the cohere /generate handler
Ishaan Jaff 2024-12-10 10:12:22 -08:00 committed by GitHub
parent 28ff38e35d
commit 9c2316b7ec
14 changed files with 391 additions and 268 deletions
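For reference, a minimal sketch of how the refactored provider is exercised end to end. The model name and prompt mirror the async streaming test added in this commit; the function name is hypothetical and Cloudflare credentials are assumed to be configured for litellm.

import asyncio
import litellm

async def demo_cloudflare_stream():
    # After this refactor the call is routed through the common base LLM HTTP handler.
    response = await litellm.acompletion(
        model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
        messages=[{"content": "what llm are you", "role": "user"}],
        max_tokens=15,
        stream=True,
    )
    async for chunk in response:
        print(chunk)

# Requires Cloudflare credentials to be configured for litellm:
# asyncio.run(demo_cloudflare_stream())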

View file

@ -1067,10 +1067,10 @@ from .llms.predibase import PredibaseConfig
from .llms.replicate import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.cloudflare.chat.transformation import CloudflareChatConfig
from .llms.ai21.completion import AI21Config
from .llms.ai21.chat import AI21ChatConfig
from .llms.together_ai.chat import TogetherAIConfig
from .llms.cloudflare import CloudflareConfig
from .llms.palm import PalmConfig
from .llms.gemini import GeminiConfig
from .llms.nlp_cloud import NLPCloudConfig

View file

@ -195,7 +195,7 @@ def get_supported_openai_params( # noqa: PLR0915
"stop", "stop",
] ]
elif custom_llm_provider == "cloudflare": elif custom_llm_provider == "cloudflare":
return ["max_tokens", "stream"] return litellm.CloudflareChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "nlp_cloud": elif custom_llm_provider == "nlp_cloud":
return [ return [
"max_tokens", "max_tokens",

View file

@ -630,36 +630,6 @@ class CustomStreamWrapper:
)
return ""
def handle_cloudlfare_stream(self, chunk):
try:
print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
chunk = chunk.decode("utf-8")
str_line = chunk
text = ""
is_finished = False
finish_reason = None
if "[DONE]" in chunk:
return {"text": text, "is_finished": True, "finish_reason": "stop"}
elif str_line.startswith("data:"):
data_json = json.loads(str_line[5:])
print_verbose(f"delta content: {data_json}")
text = data_json["response"]
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
else:
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except Exception as e:
raise e
def handle_ollama_stream(self, chunk):
try:
if isinstance(chunk, dict):
@ -1226,12 +1196,6 @@ class CustomStreamWrapper:
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "cloudflare":
response_obj = self.handle_cloudlfare_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "watsonx":
response_obj = self.handle_watsonx_stream(chunk)
completion_obj["content"] = response_obj["text"]
@ -1722,6 +1686,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "bedrock"
or self.custom_llm_provider == "triton"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider == "cloudflare"
or self.custom_llm_provider in litellm.openai_compatible_providers
or self.custom_llm_provider in litellm._custom_providers
):

View file

@ -1,5 +1,5 @@
import json
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import litellm import litellm
@ -12,6 +12,103 @@ from litellm.types.utils import (
)
class BaseModelResponseIterator:
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.json_mode = json_mode
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
# Sync iterator
def __iter__(self):
return self
def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
# chunk is a str at this point
if "[DONE]" in str_line:
return GenericStreamingChunk(
text="",
is_finished=True,
finish_reason="stop",
usage=None,
index=0,
tool_use=None,
)
elif str_line.startswith("data:"):
data_json = json.loads(str_line[5:])
return self.chunk_parser(chunk=data_json)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
# chunk is a str at this point
return self._handle_string_chunk(str_line=str_line)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
# chunk is a str at this point
return self._handle_string_chunk(str_line=str_line)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
class FakeStreamResponseIterator:
def __init__(self, model_response, json_mode: Optional[bool] = False):
self.model_response = model_response
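To illustrate the new base iterator: a provider only overrides chunk_parser, while the base class handles the "data:" framing and the "[DONE]" sentinel. A minimal sketch under those assumptions; the EchoResponseIterator class and sample lines are hypothetical, and the import paths are the ones used later in this diff.

from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.utils import GenericStreamingChunk

class EchoResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        # Map a provider-specific JSON chunk onto the generic chunk shape.
        return GenericStreamingChunk(
            text=chunk.get("response", ""),
            is_finished=False,
            finish_reason="",
            usage=None,
            index=0,
            tool_use=None,
        )

sse_lines = ['data: {"response": "Hello"}', "data: [DONE]"]
for parsed in EchoResponseIterator(iter(sse_lines), sync_stream=True):
    print(parsed)  # first a text chunk, then a finished "stop" chunk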

View file

@ -95,6 +95,16 @@ class BaseConfig(ABC):
) -> dict:
pass
def get_complete_url(self, api_base: str, model: str) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
return api_base
@abstractmethod
def transform_request(
self,
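A minimal sketch of why this hook exists: Cloudflare appends the model name to api_base (see CloudflareChatConfig.get_complete_url later in this diff). The example api_base below is an assumption about how the Workers AI endpoint is configured, not a value taken from this commit.

def get_complete_url(api_base: str, model: str) -> str:
    # Mirrors the Cloudflare override: the model segment is part of the URL path.
    return api_base + model

url = get_complete_url(
    api_base="https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/run/",
    model="@cf/meta/llama-2-7b-chat-int8",
)
print(url)  # .../ai/run/@cf/meta/llama-2-7b-chat-int8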

View file

@ -1,180 +0,0 @@
import json
import os
import time
import types
from enum import Enum
from typing import Callable, Optional
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
class CloudflareError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cloudflare.com")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CloudflareConfig:
max_tokens: Optional[int] = None
stream: Optional[bool] = None
def __init__(
self,
max_tokens: Optional[int] = None,
stream: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
if api_key is None:
raise ValueError(
"Missing CloudflareError API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
)
headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": "Bearer " + api_key,
}
return headers
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
custom_prompt_dict={},
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
## Load Config
config = litellm.CloudflareConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
# cloudflare adds the model to the api base
api_base = api_base + model
data = {
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] is True:
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
if response.status_code != 200:
raise CloudflareError(
status_code=response.status_code, message=response.text
)
completion_response = response.json()
model_response.choices[0].message.content = completion_response["result"][ # type: ignore
"response"
]
## CALCULATING USAGE
print_verbose(
f"CALCULATING CLOUDFLARE TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
)
prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = "cloudflare/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -0,0 +1,5 @@
"""
Cloudflare - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View file

@ -0,0 +1,202 @@
import json
import time
from typing import AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
ModelResponse,
Usage,
)
class CloudflareError(BaseLLMException):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cloudflare.com")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
request=self.request,
response=self.response,
) # Call the base class constructor with the parameters it needs
class CloudflareChatConfig(BaseConfig):
max_tokens: Optional[int] = None
stream: Optional[bool] = None
def __init__(
self,
max_tokens: Optional[int] = None,
stream: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(
"Missing CloudflareError API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
)
headers = {
"accept": "application/json",
"content-type": "apbplication/json",
"Authorization": "Bearer " + api_key,
}
return headers
def get_complete_url(self, api_base: str, model: str) -> str:
return api_base + model
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"max_tokens",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():
if param == "max_completion_tokens":
optional_params["max_tokens"] = value
elif param in supported_openai_params:
optional_params[param] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
config = litellm.CloudflareChatConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
data = {
"messages": messages,
**optional_params,
}
return data
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
completion_response = raw_response.json()
model_response.choices[0].message.content = completion_response["result"][ # type: ignore
"response"
]
prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = "cloudflare/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CloudflareError(
status_code=status_code,
message=error_message,
)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
raise NotImplementedError
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CloudflareChatResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class CloudflareChatResponseIterator(BaseModelResponseIterator):
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
if "response" in chunk:
text = chunk["response"]
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
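A minimal sketch of the payload shapes this config parses, inferred from transform_response and chunk_parser above rather than taken from Cloudflare's documentation.

from litellm.llms.cloudflare.chat.transformation import CloudflareChatResponseIterator

# Non-streaming: transform_response reads result.response from the JSON body.
non_streaming_body = {"result": {"response": "Hello from Workers AI"}}

# Streaming: each SSE "data:" line carries a JSON object with a "response" field.
iterator = CloudflareChatResponseIterator(streaming_response=iter([]), sync_stream=True)
parsed = iterator.chunk_parser(chunk={"response": "Hello"})
print(parsed["text"], parsed["is_finished"])  # Hello False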

View file

@ -0,0 +1,5 @@
"""
Cohere /generate API - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View file

@ -13,7 +13,6 @@ from typing import (
)
import httpx # type: ignore
import requests # type: ignore
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
import litellm
@ -109,6 +108,11 @@ class BaseLLMHTTPHandler:
optional_params=optional_params,
)
api_base = provider_config.get_complete_url(
api_base=api_base,
model=model,
)
data = provider_config.transform_request(
model=model,
messages=messages,

View file

@ -86,7 +86,6 @@ from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
from .llms import (
aleph_alpha,
baseten,
cloudflare,
maritalk,
nlp_cloud,
ollama,
@ -471,6 +470,7 @@ async def acompletion(
or custom_llm_provider == "triton"
or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider == "cloudflare"
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider in litellm._custom_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@ -2828,37 +2828,22 @@ def completion( # type: ignore # noqa: PLR0915
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = cloudflare.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="cloudflare",
logging_obj=logging,
)
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
)
response = response
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider="cloudflare",
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
elif (
custom_llm_provider == "baseten"
or litellm.api_base == "https://app.baseten.co"

View file

@ -3274,10 +3274,16 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if stream is not None:
optional_params["stream"] = stream
optional_params = litellm.CloudflareChatConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "ollama":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
@ -6248,6 +6254,8 @@ class ProviderConfigManager:
elif litellm.LlmProviders.VERTEX_AI == provider:
if "claude" in model:
return litellm.VertexAIAnthropicConfig()
elif litellm.LlmProviders.CLOUDFLARE == provider:
return litellm.CloudflareChatConfig()
return litellm.OpenAIGPTConfig()
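A minimal sketch of the new parameter mapping (behavior taken from CloudflareChatConfig.map_openai_params earlier in this diff; the call site above passes the same arguments):

import litellm

mapped = litellm.CloudflareChatConfig().map_openai_params(
    non_default_params={"max_completion_tokens": 15, "stream": True},
    optional_params={},
    model="@cf/meta/llama-2-7b-chat-int8",
    drop_params=False,
)
print(mapped)  # {'max_tokens': 15, 'stream': True}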

View file

@ -0,0 +1,42 @@
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import json
import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
# Cloud flare AI test
@pytest.mark.asyncio
@pytest.mark.parametrize("stream", [True, False])
async def test_completion_cloudflare(stream):
try:
litellm.set_verbose = False
response = await litellm.acompletion(
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
messages=[{"content": "what llm are you", "role": "user"}],
max_tokens=15,
stream=stream,
)
print(response)
if stream is True:
async for chunk in response:
print(chunk)
else:
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -4181,26 +4181,6 @@ def test_completion_together_ai_stream():
# test_completion_together_ai_stream()
# Cloud flare AI tests
@pytest.mark.skip(reason="Flaky test-cloudflare is very unstable")
def test_completion_cloudflare():
try:
litellm.set_verbose = True
response = completion(
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
messages=[{"content": "what llm are you", "role": "user"}],
max_tokens=15,
num_retries=3,
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_cloudflare()
def test_moderation():
response = litellm.moderation(input="i'm ishaan cto of litellm")
print(response)