diff --git a/litellm/llms/cloudflare.py b/litellm/llms/cloudflare.py
index 9f1c390e6d..a9e60bb7e0 100644
--- a/litellm/llms/cloudflare.py
+++ b/litellm/llms/cloudflare.py
@@ -23,10 +23,12 @@ class CloudflareError(Exception):
 
 class CloudflareConfig:
     max_tokens: Optional[int] = None
+    stream: Optional[bool] = None
 
     def __init__(
         self,
         max_tokens: Optional[int] = None,
+        stream: Optional[bool] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
diff --git a/litellm/utils.py b/litellm/utils.py
index 4a983f9be6..a1f0749315 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3398,6 +3398,16 @@ def get_optional_params(
         optional_params["n"] = n
         if stop is not None:
             optional_params["stop_sequences"] = stop
+    elif (
+        custom_llm_provider == "cloudflare"
+    ):  # https://developers.cloudflare.com/workers-ai/models/text-generation/#input
+        supported_params = ["max_tokens", "stream"]
+        _check_valid_arg(supported_params=supported_params)
+
+        if max_tokens is not None:
+            optional_params["max_tokens"] = max_tokens
+        if stream is not None:
+            optional_params["stream"] = stream
     elif custom_llm_provider == "ollama":
         supported_params = [
             "max_tokens",
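
For reference, a minimal usage sketch of what this change enables. This is not part of the diff: the model id is illustrative, and it assumes litellm's Cloudflare provider reads credentials from the CLOUDFLARE_API_KEY and CLOUDFLARE_ACCOUNT_ID environment variables.

    import litellm

    # With this patch, stream=True and max_tokens are recognized as supported
    # params for the cloudflare provider and forwarded to the Workers AI
    # endpoint instead of being rejected by _check_valid_arg.
    # The model id below is an example, not prescribed by this diff.
    response = litellm.completion(
        model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
        messages=[{"role": "user", "content": "Say hello"}],
        max_tokens=64,
        stream=True,
    )
    for chunk in response:
        print(chunk)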