litellm-mirror/litellm/llms/cloudflare/chat/transformation.py
Krish Dholakia 1a4910f6c0 fix(health.md): add rerank model health check information (#7295)
* fix(health.md): add rerank model health check information

* build(model_prices_and_context_window.json): add gemini 2.0 for google ai studio - pricing + commercial rate limits

* build(model_prices_and_context_window.json): add gemini-2.0 supports audio output = true

* docs(team_model_add.md): clarify allowing teams to add models is an enterprise feature

* fix(o1_transformation.py): add support for 'n', 'response_format' and 'stop' params for o1 and 'stream_options' param for o1-mini

* build(model_prices_and_context_window.json): add 'supports_system_message' to supporting openai models

needed as o1-preview, and o1-mini models don't support 'system message

* fix(o1_transformation.py): translate system message based on if o1 model supports it

* fix(o1_transformation.py): return 'stream' param support if o1-mini/o1-preview

o1 currently doesn't support streaming, but the other model versions do

Fixes https://github.com/BerriAI/litellm/issues/7292

* fix(o1_transformation.py): return tool calling/response_format in supported params if model map says so

Fixes https://github.com/BerriAI/litellm/issues/7292

* fix: fix linting errors

* fix: update '_transform_messages'

* fix(o1_transformation.py): fix provider passed for supported param checks

* test(base_llm_unit_tests.py): skip test if api takes >5s to respond

* fix(utils.py): return false in 'supports_factory' if can't find value

* fix(o1_transformation.py): always return stream + stream_options as supported params + handle stream options being passed in for azure o1

* feat(openai.py): support stream faking natively in openai handler

Allows o1 calls to be faked for just the "o1" model, allows native streaming for o1-mini, o1-preview

 Fixes https://github.com/BerriAI/litellm/issues/7292

* fix(openai.py): use inference param instead of original optional param
2024-12-18 19:18:10 -08:00

202 lines
6.2 KiB
Python

import json
import time
from typing import AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
ModelResponse,
Usage,
)
class CloudflareError(BaseLLMException):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cloudflare.com")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
request=self.request,
response=self.response,
) # Call the base class constructor with the parameters it needs
class CloudflareChatConfig(BaseConfig):
max_tokens: Optional[int] = None
stream: Optional[bool] = None
def __init__(
self,
max_tokens: Optional[int] = None,
stream: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(
"Missing CloudflareError API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
)
headers = {
"accept": "application/json",
"content-type": "apbplication/json",
"Authorization": "Bearer " + api_key,
}
return headers
def get_complete_url(self, api_base: str, model: str) -> str:
return api_base + model
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"max_tokens",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():
if param == "max_completion_tokens":
optional_params["max_tokens"] = value
elif param in supported_openai_params:
optional_params[param] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
config = litellm.CloudflareChatConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
data = {
"messages": messages,
**optional_params,
}
return data
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
completion_response = raw_response.json()
model_response.choices[0].message.content = completion_response["result"][ # type: ignore
"response"
]
prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = "cloudflare/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CloudflareError(
status_code=status_code,
message=error_message,
)
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CloudflareChatResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class CloudflareChatResponseIterator(BaseModelResponseIterator):
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
if "response" in chunk:
text = chunk["response"]
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")