mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00

(Refactor) Code Quality improvement - Use Common base handler for cloudflare/ provider (#7127)

* add get_complete_url to base config
* cloudflare - refactor to follow existing pattern
* migrate cloudflare chat completions to base llm http handler
* fix unused import
* fix fake stream in cloudflare
* fix cloudflare transformation
* fix naming for BaseModelResponseIterator
* add async cloudflare streaming test
* test cloudflare
* add handler.py
* add handler.py for cohere

This commit is contained in:
parent 28ff38e35d
commit 9c2316b7ec

14 changed files with 391 additions and 268 deletions

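For reference, the end-user call path is unchanged by this refactor; a minimal sketch of a Cloudflare completion through litellm (model name taken from the test added in this commit; Cloudflare credentials are assumed to be configured in the environment):

import litellm

# Minimal sketch - assumes Cloudflare credentials (API key / account settings)
# are already configured in the environment.
response = litellm.completion(
    model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
    messages=[{"content": "what llm are you", "role": "user"}],
    max_tokens=15,
)
print(response.choices[0].message.content)
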
@@ -1067,10 +1067,10 @@ from .llms.predibase import PredibaseConfig
 from .llms.replicate import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
+from .llms.cloudflare.chat.transformation import CloudflareChatConfig
 from .llms.ai21.completion import AI21Config
 from .llms.ai21.chat import AI21ChatConfig
 from .llms.together_ai.chat import TogetherAIConfig
-from .llms.cloudflare import CloudflareConfig
 from .llms.palm import PalmConfig
 from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig

@@ -195,7 +195,7 @@ def get_supported_openai_params( # noqa: PLR0915
             "stop",
         ]
     elif custom_llm_provider == "cloudflare":
-        return ["max_tokens", "stream"]
+        return litellm.CloudflareChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "nlp_cloud":
         return [
             "max_tokens",

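With this change the supported-parameter list is owned by the provider config rather than hard-coded here. A small sketch of the delegated call (values per the CloudflareChatConfig added later in this commit):

import litellm

supported = litellm.CloudflareChatConfig().get_supported_openai_params(
    model="@cf/meta/llama-2-7b-chat-int8"
)
print(supported)  # ["stream", "max_tokens"] per transformation.py in this commit
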
@@ -630,36 +630,6 @@ class CustomStreamWrapper:
             )
             return ""
 
-    def handle_cloudlfare_stream(self, chunk):
-        try:
-            print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
-            chunk = chunk.decode("utf-8")
-            str_line = chunk
-            text = ""
-            is_finished = False
-            finish_reason = None
-
-            if "[DONE]" in chunk:
-                return {"text": text, "is_finished": True, "finish_reason": "stop"}
-            elif str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
-                print_verbose(f"delta content: {data_json}")
-                text = data_json["response"]
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-            else:
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-
-        except Exception as e:
-            raise e
-
     def handle_ollama_stream(self, chunk):
         try:
             if isinstance(chunk, dict):

@@ -1226,12 +1196,6 @@ class CustomStreamWrapper:
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider == "cloudflare":
-                response_obj = self.handle_cloudlfare_stream(chunk)
-                completion_obj["content"] = response_obj["text"]
-                print_verbose(f"completion obj content: {completion_obj['content']}")
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "watsonx":
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]

@@ -1722,6 +1686,7 @@ class CustomStreamWrapper:
             or self.custom_llm_provider == "bedrock"
             or self.custom_llm_provider == "triton"
             or self.custom_llm_provider == "watsonx"
+            or self.custom_llm_provider == "cloudflare"
             or self.custom_llm_provider in litellm.openai_compatible_providers
             or self.custom_llm_provider in litellm._custom_providers
         ):

@@ -1,5 +1,5 @@
 import json
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from typing import List, Optional, Tuple
 
 import litellm

@@ -12,6 +12,103 @@ from litellm.types.utils import (
 )
 
 
+class BaseModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        return GenericStreamingChunk(
+            text="",
+            is_finished=False,
+            finish_reason="",
+            usage=None,
+            index=0,
+            tool_use=None,
+        )
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
+        # chunk is a str at this point
+        if "[DONE]" in str_line:
+            return GenericStreamingChunk(
+                text="",
+                is_finished=True,
+                finish_reason="stop",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
+        elif str_line.startswith("data:"):
+            data_json = json.loads(str_line[5:])
+            return self.chunk_parser(chunk=data_json)
+        else:
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+            # chunk is a str at this point
+            return self._handle_string_chunk(str_line=str_line)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            # chunk is a str at this point
+            return self._handle_string_chunk(str_line=str_line)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+
 class FakeStreamResponseIterator:
     def __init__(self, model_response, json_mode: Optional[bool] = False):
         self.model_response = model_response

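As a minimal illustration of the intended usage (hypothetical provider name): a provider-specific iterator only overrides chunk_parser, while the base class handles bytes decoding, the "data:" prefix, and "[DONE]". This mirrors the CloudflareChatResponseIterator added later in this commit.

from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.utils import GenericStreamingChunk


class MyProviderResponseIterator(BaseModelResponseIterator):  # hypothetical subclass
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        # The base class already gave us a parsed JSON dict; we only translate
        # the provider's payload into a GenericStreamingChunk.
        return GenericStreamingChunk(
            text=chunk.get("response", ""),
            is_finished=False,
            finish_reason="",
            usage=None,
            index=0,
            tool_use=None,
        )
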
@@ -95,6 +95,16 @@ class BaseConfig(ABC):
     ) -> dict:
         pass
 
+    def get_complete_url(self, api_base: str, model: str) -> str:
+        """
+        OPTIONAL
+
+        Get the complete url for the request
+
+        Some providers need `model` in `api_base`
+        """
+        return api_base
+
     @abstractmethod
     def transform_request(
         self,

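To illustrate the new hook: a provider that needs the model appended to its base URL overrides get_complete_url, as the Cloudflare config does later in this commit. A hedged sketch (the api_base value below is only an example placeholder, not resolved credentials):

import litellm

# Example values only - the real api_base is resolved by litellm at request time.
url = litellm.CloudflareChatConfig().get_complete_url(
    api_base="https://api.cloudflare.com/client/v4/accounts/<ACCOUNT_ID>/ai/run/",
    model="@cf/meta/llama-2-7b-chat-int8",
)
print(url)  # base url with the model appended
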
@@ -1,180 +0,0 @@
-import json
-import os
-import time
-import types
-from enum import Enum
-from typing import Callable, Optional
-
-import httpx  # type: ignore
-import requests  # type: ignore
-
-import litellm
-from litellm.utils import ModelResponse, Usage
-
-from .prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class CloudflareError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(method="POST", url="https://api.cloudflare.com")
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class CloudflareConfig:
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        stream: Optional[bool] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-
-def validate_environment(api_key):
-    if api_key is None:
-        raise ValueError(
-            "Missing CloudflareError API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
-        )
-    headers = {
-        "accept": "application/json",
-        "content-type": "application/json",
-        "Authorization": "Bearer " + api_key,
-    }
-    return headers
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params: dict,
-    custom_prompt_dict={},
-    litellm_params=None,
-    logger_fn=None,
-):
-    headers = validate_environment(api_key)
-
-    ## Load Config
-    config = litellm.CloudflareConfig.get_config()
-    for k, v in config.items():
-        if k not in optional_params:
-            optional_params[k] = v
-
-    print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_details = custom_prompt_dict[model]
-        custom_prompt(
-            role_dict=model_prompt_details.get("roles", {}),
-            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
-            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-            bos_token=model_prompt_details.get("bos_token", ""),
-            eos_token=model_prompt_details.get("eos_token", ""),
-            messages=messages,
-        )
-
-    # cloudflare adds the model to the api base
-    api_base = api_base + model
-
-    data = {
-        "messages": messages,
-        **optional_params,
-    }
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=messages,
-        api_key=api_key,
-        additional_args={
-            "headers": headers,
-            "api_base": api_base,
-            "complete_input_dict": data,
-        },
-    )
-
-    ## COMPLETION CALL
-    if "stream" in optional_params and optional_params["stream"] is True:
-        response = requests.post(
-            api_base,
-            headers=headers,
-            data=json.dumps(data),
-            stream=optional_params["stream"],
-        )
-        return response.iter_lines()
-    else:
-        response = requests.post(api_base, headers=headers, data=json.dumps(data))
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        if response.status_code != 200:
-            raise CloudflareError(
-                status_code=response.status_code, message=response.text
-            )
-        completion_response = response.json()
-
-        model_response.choices[0].message.content = completion_response["result"][  # type: ignore
-            "response"
-        ]
-
-    ## CALCULATING USAGE
-    print_verbose(
-        f"CALCULATING CLOUDFLARE TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
-    )
-    prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
-    completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-    )
-
-    model_response.created = int(time.time())
-    model_response.model = "cloudflare/" + model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
-
-
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass

litellm/llms/cloudflare/chat/handler.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+"""
+Cloudflare - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""

litellm/llms/cloudflare/chat/transformation.py (new file, 202 lines)
@@ -0,0 +1,202 @@
+import json
+import time
+from typing import AsyncIterator, Iterator, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
+from litellm.llms.base_llm.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+    ModelResponse,
+    Usage,
+)
+
+
+class CloudflareError(BaseLLMException):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.cloudflare.com")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=self.request,
+            response=self.response,
+        )  # Call the base class constructor with the parameters it needs
+
+
+class CloudflareChatConfig(BaseConfig):
+    max_tokens: Optional[int] = None
+    stream: Optional[bool] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        stream: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "Missing Cloudflare API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
+            )
+        headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "Authorization": "Bearer " + api_key,
+        }
+        return headers
+
+    def get_complete_url(self, api_base: str, model: str) -> str:
+        return api_base + model
+
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        return [
+            "stream",
+            "max_tokens",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        for param, value in non_default_params.items():
+            if param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            elif param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        config = litellm.CloudflareChatConfig.get_config()
+        for k, v in config.items():
+            if k not in optional_params:
+                optional_params[k] = v
+
+        data = {
+            "messages": messages,
+            **optional_params,
+        }
+        return data
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        completion_response = raw_response.json()
+
+        model_response.choices[0].message.content = completion_response["result"][  # type: ignore
+            "response"
+        ]
+
+        prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = "cloudflare/" + model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CloudflareError(
+            status_code=status_code,
+            message=error_message,
+        )
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        raise NotImplementedError
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return CloudflareChatResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class CloudflareChatResponseIterator(BaseModelResponseIterator):
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "response" in chunk:
+                text = chunk["response"]
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

litellm/llms/cohere/completion/handler.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+"""
+Cohere /generate API - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""

@@ -13,7 +13,6 @@ from typing import (
 )
 
 import httpx  # type: ignore
-import requests  # type: ignore
 from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
 
 import litellm

@@ -109,6 +108,11 @@ class BaseLLMHTTPHandler:
             optional_params=optional_params,
         )
 
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+        )
+
         data = provider_config.transform_request(
             model=model,
             messages=messages,

@@ -86,7 +86,6 @@ from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
 from .llms import (
     aleph_alpha,
     baseten,
-    cloudflare,
     maritalk,
     nlp_cloud,
     ollama,

@@ -471,6 +470,7 @@ async def acompletion(
         or custom_llm_provider == "triton"
         or custom_llm_provider == "clarifai"
         or custom_llm_provider == "watsonx"
+        or custom_llm_provider == "cloudflare"
         or custom_llm_provider in litellm.openai_compatible_providers
         or custom_llm_provider in litellm._custom_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.

@@ -2828,37 +2828,22 @@ def completion( # type: ignore # noqa: PLR0915
             )
 
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = cloudflare.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
+            acompletion=acompletion,
             api_base=api_base,
-            custom_prompt_dict=litellm.custom_prompt_dict,
             model_response=model_response,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,  # for calculating input/output tokens
+            custom_llm_provider="cloudflare",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
             api_key=api_key,
-            logging_obj=logging,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )
-        if "stream" in optional_params and optional_params["stream"] is True:
-            # don't try to access stream object,
-            response = CustomStreamWrapper(
-                response,
-                model,
-                custom_llm_provider="cloudflare",
-                logging_obj=logging,
-            )
-
-        if optional_params.get("stream", False) or acompletion is True:
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-            )
-        response = response
     elif (
         custom_llm_provider == "baseten"
         or litellm.api_base == "https://app.baseten.co"

@@ -3274,10 +3274,16 @@ def get_optional_params( # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
 
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if stream is not None:
-            optional_params["stream"] = stream
+        optional_params = litellm.CloudflareChatConfig().map_openai_params(
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "ollama":
         supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider

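A short sketch of what the delegated mapping does (illustrative values; behaviour per the map_openai_params implementation added in transformation.py above):

import litellm

mapped = litellm.CloudflareChatConfig().map_openai_params(
    non_default_params={"max_completion_tokens": 15, "stream": True},
    optional_params={},
    model="@cf/meta/llama-2-7b-chat-int8",
    drop_params=False,
)
print(mapped)  # {"max_tokens": 15, "stream": True}
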
@@ -6248,6 +6254,8 @@ class ProviderConfigManager:
         elif litellm.LlmProviders.VERTEX_AI == provider:
             if "claude" in model:
                 return litellm.VertexAIAnthropicConfig()
+        elif litellm.LlmProviders.CLOUDFLARE == provider:
+            return litellm.CloudflareChatConfig()
 
         return litellm.OpenAIGPTConfig()
 

tests/llm_translation/test_cloudflare.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import json
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+
+
+# Cloud flare AI test
+@pytest.mark.asyncio
+@pytest.mark.parametrize("stream", [True, False])
+async def test_completion_cloudflare(stream):
+    try:
+        litellm.set_verbose = False
+        response = await litellm.acompletion(
+            model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
+            messages=[{"content": "what llm are you", "role": "user"}],
+            max_tokens=15,
+            stream=stream,
+        )
+        print(response)
+        if stream is True:
+            async for chunk in response:
+                print(chunk)
+        else:
+            print(response)
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

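For quick local checks, a synchronous streaming variant of the same call is sketched below (assumes Cloudflare credentials, such as the API key and account settings, are present in the environment; the exact env var names are not part of this commit):

import litellm

# Sketch only - requires Cloudflare credentials in the environment.
response = litellm.completion(
    model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
    messages=[{"content": "what llm are you", "role": "user"}],
    max_tokens=15,
    stream=True,
)
for chunk in response:
    print(chunk)
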
@@ -4181,26 +4181,6 @@ def test_completion_together_ai_stream():
 # test_completion_together_ai_stream()
 
 
-# Cloud flare AI tests
-@pytest.mark.skip(reason="Flaky test-cloudflare is very unstable")
-def test_completion_cloudflare():
-    try:
-        litellm.set_verbose = True
-        response = completion(
-            model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
-            messages=[{"content": "what llm are you", "role": "user"}],
-            max_tokens=15,
-            num_retries=3,
-        )
-        print(response)
-
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-# test_completion_cloudflare()
-
-
 def test_moderation():
     response = litellm.moderation(input="i'm ishaan cto of litellm")
     print(response)