From 0834ffaae32fae89aac103a00f5461c1debd1b27 Mon Sep 17 00:00:00 2001
From: Sunny Wan
Date: Tue, 11 Mar 2025 02:00:52 -0400
Subject: [PATCH] Remove handler and refactor to deepseek/chat format

---
 litellm/llms/snowflake/completion/handler.py | 63 -----
 .../snowflake/completion/transformation.py   | 215 ++++++++++++++----
 litellm/main.py                              | 54 +++--
 3 files changed, 204 insertions(+), 128 deletions(-)
 delete mode 100644 litellm/llms/snowflake/completion/handler.py

diff --git a/litellm/llms/snowflake/completion/handler.py b/litellm/llms/snowflake/completion/handler.py
deleted file mode 100644
index 85ec676606..0000000000
--- a/litellm/llms/snowflake/completion/handler.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from litellm.llms.base import BaseLLM
-from typing import Any, List, Optional
-from typing import List, Dict, Callable, Optional, Any, cast, Union
-
-import litellm
-from litellm.utils import ModelResponse
-from litellm.types.llms.openai import AllMessageValues
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
-from ..common_utils import SnowflakeBase
-
-class SnowflakeChatCompletion(OpenAILikeChatHandler,SnowflakeBase):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def completion(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        api_base: str,
-        custom_prompt_dict: dict,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        encoding,
-        JWT: str,
-        logging_obj,
-        optional_params: dict,
-        acompletion=None,
-        litellm_params=None,
-        logger_fn=None,
-        headers: Optional[dict] = None,
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
-    ) -> None:
-
-        messages = litellm.SnowflakeConfig()._transform_messages(
-            messages=cast(List[AllMessageValues], messages), model=model
-        )
-
-        headers = self.validate_environment(
-            headers,
-            JWT
-        )
-
-        return super().completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            custom_llm_provider= "snowflake",
-            custom_prompt_dict=custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            encoding=encoding,
-            api_key=JWT,
-            logging_obj=logging_obj,
-            optional_params=optional_params,
-            acompletion=acompletion,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=headers,
-            client=client,
-            custom_endpoint=True,
-        )
diff --git a/litellm/llms/snowflake/completion/transformation.py b/litellm/llms/snowflake/completion/transformation.py
index 79b25925e7..48593cf0db 100644
--- a/litellm/llms/snowflake/completion/transformation.py
+++ b/litellm/llms/snowflake/completion/transformation.py
@@ -2,52 +2,27 @@
 Support for Snowflake REST API
 '''
 import httpx
-from typing import List, Optional, Union, Any
+from typing import List, Optional, Tuple, Any, TYPE_CHECKING
 
-import litellm
-from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
-from litellm.litellm_core_utils.prompt_templates.common_utils import (
-    convert_content_list_to_str,
-)
-from ...openai_like.chat.transformation import OpenAILikeChatConfig
+from litellm.utils import get_secret
+from litellm.types.utils import ModelResponse
+from litellm.types.llms.openai import ChatCompletionAssistantMessage
+from litellm.llms.databricks.streaming_utils import ModelResponseIterator
+from ...openai_like.chat.transformation import OpenAIGPTConfig
 
-class SnowflakeConfig(OpenAILikeChatConfig):
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+class SnowflakeConfig(OpenAIGPTConfig):
     """
     source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
-
-    The class `SnowflakeConfig` provides configuration for Snowflake's REST API interface. Below are the parameters:
-
-    - `temperature` (float, optional): A value between 0 and 1 that controls randomness. Lower temperatures mean lower randomness. Default: 0
-
-    - `top_p` (float, optional): Limits generation at each step to top `k` most likely tokens. Default: 0
-
-    - `max_tokens `(int, optional): The maximum number of tokens in the response. Default: 4096. Maximum allowed: 8192.
-
-    - `guardrails` (bool, optional): Whether to enable Cortex Guard to filter potentially unsafe responses. Default: False.
-
-    - `response_format` (str, optional): A JSON schema that the response should follow
-    """
-    temperature: Optional[float]
-    top_p: Optional[float]
-    max_tokens: Optional[int]
-    guardrails: Optional[bool]
-    response_format: Optional[str]
-
-    def __init__(
-        self,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        guardrails: Optional[bool] = None,
-        response_format: Optional[str] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
+    """
 
     @classmethod
     def get_config(cls):
@@ -60,7 +35,7 @@ class SnowflakeConfig(OpenAILikeChatConfig):
             "top_p",
             "response_format"
         ]
-    
+
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -83,4 +58,160 @@ class SnowflakeConfig(OpenAILikeChatConfig):
         for param, value in non_default_params.items():
             if param in supported_openai_params:
                 optional_params[param] = value
-        return optional_params
\ No newline at end of file
+        return optional_params
+
+    @staticmethod
+    def _convert_tool_response_to_message(
+        message: ChatCompletionAssistantMessage, json_mode: bool
+    ) -> ChatCompletionAssistantMessage:
+        """
+        If json_mode is true, convert the returned tool call response to an assistant message whose content is the JSON argument string.
+
+        e.g. input:
+
+        {"role": "assistant", "tool_calls": [{"id": "call_5ms4", "type": "function", "function": {"name": "json_tool_call", "arguments": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}}]}
+
+        output:
+
+        {"role": "assistant", "content": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}
+        """
+        if not json_mode:
+            return message
+
+        _tool_calls = message.get("tool_calls")
+
+        if _tool_calls is None or len(_tool_calls) != 1:
+            return message
+
+        message["content"] = _tool_calls[0]["function"].get("arguments") or ""
+        message["tool_calls"] = None
+
+        return message
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        response_json = raw_response.json()
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        if json_mode:
+            for choice in response_json["choices"]:
+                message = SnowflakeConfig._convert_tool_response_to_message(
+                    choice.get("message"), json_mode
+                )
+                choice["message"] = message
+
+        returned_response = ModelResponse(**response_json)
+
+        returned_response.model = (
+            "snowflake/" + (returned_response.model or "")
+        )
+
+        if model is not None:
+            returned_response._hidden_params["model"] = model
+        return returned_response
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        messages: Optional[List[AllMessageValues]] = None,
+        optional_params: Optional[dict] = None,
+    ) -> dict:
+        """
+        Return headers to use for the Snowflake completion request.
+
+        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
+        Expected headers:
+        {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": "Bearer " + <JWT>,
+            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+        }
+        """
+
+        if api_key is None:
+            raise ValueError(
+                "Missing Snowflake JWT key"
+            )
+
+        headers.update(
+            {
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                "Authorization": "Bearer " + api_key,
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
+            }
+        )
+        return headers
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        # Prefer an explicit api_base, then the SNOWFLAKE_API_BASE secret; the
+        # account-scoped default URL must come last, since the f-string is
+        # always truthy and would otherwise make the secret unreachable.
+        api_base = (
+            api_base
+            or get_secret("SNOWFLAKE_API_BASE")
+            or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        If api_base is not provided, default to the Snowflake Cortex inference endpoint.
+        """
+        if not api_base:
+            api_base = f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+
+        return api_base
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        stream: bool = optional_params.pop("stream", None) or False
+        extra_body = optional_params.pop("extra_body", {})
+        return {
+            "model": model,
+            "messages": messages,
+            "stream": stream,
+            **optional_params,
+            **extra_body,
+        }
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: ModelResponse,
+        sync_stream: bool,
+    ):
+        return ModelResponseIterator(streaming_response=streaming_response, sync_stream=sync_stream)
\ No newline at end of file
diff --git a/litellm/main.py b/litellm/main.py
index e4abbd8458..9244e47d49 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -146,7 +146,6 @@ from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.petals.completion import handler as petals_handler
 from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
-from .llms.snowflake.completion.handler import SnowflakeChatCompletion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.vertex_ai import vertex_ai_non_gemini
@@ -237,7 +236,6 @@ databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
 sagemaker_chat_completion = SagemakerChatHandler()
-snow_flake_chat_completion = SnowflakeChatCompletion()
 
 ####### COMPLETION ENDPOINTS ################
@@ -2977,27 +2975,37 @@ def completion(  # type: ignore # noqa: PLR0915
                 return response
             response = model_response
         elif custom_llm_provider == "snowflake" or model in litellm.snowflake_models:
-            api_base = (
-                api_base
-                or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
-                or get_secret("SNOWFLAKE_API_BASE")
-            )
-            response = snow_flake_chat_completion.completion(
-                model=model,
-                messages=messages,
-                api_base=api_base,
-                acompletion=acompletion,
-                custom_prompt_dict=litellm.custom_prompt_dict,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                JWT=api_key,
-                logging_obj=logging,
-                headers=headers,
-            )
+            try:
+                client = HTTPHandler(timeout=timeout) if stream is False else None  # Keep this here; otherwise the httpx client closes and streaming becomes impossible
+                response = base_llm_http_handler.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    timeout=timeout,  # type: ignore
+                    client=client,
+                    custom_llm_provider=custom_llm_provider,
+                    encoding=encoding,
+                    stream=stream,
+                )
+            except Exception as e:
+                ## LOGGING - log the original exception returned
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=str(e),
+                    additional_args={"headers": headers},
+                )
+                raise e
+
         elif custom_llm_provider == "custom":
             url = litellm.api_base or api_base or ""
             if url is None or url == "":
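
--
Example usage after this refactor (a minimal sketch, not part of the diff:
the model name `snowflake/mistral-7b` is illustrative, and SNOWFLAKE_JWT /
SNOWFLAKE_ACCOUNT_ID are the secrets read by
SnowflakeConfig._get_openai_compatible_provider_info above):

    import os
    import litellm

    # Key-pair JWT: forwarded as "Authorization: Bearer <JWT>" together with
    # "X-Snowflake-Authorization-Token-Type: KEYPAIR_JWT" by validate_environment.
    os.environ["SNOWFLAKE_JWT"] = "<key-pair JWT>"
    # Account identifier: used to build the default
    # https://<account_id>.snowflakecomputing.com/api/v2/cortex/inference:complete URL.
    os.environ["SNOWFLAKE_ACCOUNT_ID"] = "<account-id>"

    # With the handler class removed, this call routes through
    # base_llm_http_handler with custom_llm_provider="snowflake",
    # using SnowflakeConfig to transform the request and response.
    response = litellm.completion(
        model="snowflake/mistral-7b",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
    )
    print(response.choices[0].message.content)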