removed handler and refactored to deepseek/chat format

Sunny Wan 2025-03-11 02:00:52 -04:00
parent 88f165853d
commit 0834ffaae3
3 changed files with 204 additions and 128 deletions
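In practice, the change swaps the dedicated `SnowflakeChatCompletion` handler for the shared `base_llm_http_handler` dispatch used by other OpenAI-compatible providers. A minimal usage sketch (not part of the diff; the model id is a placeholder, and the environment variables are the ones read by the new code):

import os
import litellm

# Env vars consumed by SnowflakeConfig (see transformation.py below):
# SNOWFLAKE_ACCOUNT_ID builds the default Cortex endpoint URL,
# SNOWFLAKE_JWT supplies the Bearer credential.
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "my-account"  # placeholder
os.environ["SNOWFLAKE_JWT"] = "<key-pair JWT>"     # placeholder

response = litellm.completion(
    model="snowflake/mistral-7b",  # "snowflake/" routes to this provider; model id is illustrative
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)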

litellm/llms/snowflake/completion/handler.py (deleted)

@@ -1,63 +0,0 @@
-from litellm.llms.base import BaseLLM
-from typing import Any, List, Optional
-from typing import List, Dict, Callable, Optional, Any, cast, Union
-
-import litellm
-from litellm.utils import ModelResponse
-from litellm.types.llms.openai import AllMessageValues
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
-from ..common_utils import SnowflakeBase
-
-
-class SnowflakeChatCompletion(OpenAILikeChatHandler, SnowflakeBase):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def completion(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        api_base: str,
-        custom_prompt_dict: dict,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        encoding,
-        JWT: str,
-        logging_obj,
-        optional_params: dict,
-        acompletion=None,
-        litellm_params=None,
-        logger_fn=None,
-        headers: Optional[dict] = None,
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
-    ) -> None:
-        messages = litellm.SnowflakeConfig()._transform_messages(
-            messages=cast(List[AllMessageValues], messages), model=model
-        )
-
-        headers = self.validate_environment(
-            headers,
-            JWT
-        )
-
-        return super().completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            custom_llm_provider="snowflake",
-            custom_prompt_dict=custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            encoding=encoding,
-            api_key=JWT,
-            logging_obj=logging_obj,
-            optional_params=optional_params,
-            acompletion=acompletion,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=headers,
-            client=client,
-            custom_endpoint=True,
-        )

litellm/llms/snowflake/chat/transformation.py

@@ -2,52 +2,27 @@
 Support for Snowflake REST API
 '''
 import httpx
-from typing import List, Optional, Union, Any
-import litellm
-from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from typing import List, Optional, Tuple, Any, TYPE_CHECKING
+
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
-from litellm.litellm_core_utils.prompt_templates.common_utils import (
-    convert_content_list_to_str,
-)
-from ...openai_like.chat.transformation import OpenAILikeChatConfig
+from litellm.utils import get_secret
+from litellm.types.utils import ModelResponse
+from litellm.types.llms.openai import ChatCompletionAssistantMessage
+from litellm.llms.databricks.streaming_utils import ModelResponseIterator
+from ...openai_like.chat.transformation import OpenAIGPTConfig
 
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
-class SnowflakeConfig(OpenAILikeChatConfig):
+class SnowflakeConfig(OpenAIGPTConfig):
     """
     source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
-
-    The class `SnowflakeConfig` provides configuration for Snowflake's REST API interface. Below are the parameters:
-
-    - `temperature` (float, optional): A value between 0 and 1 that controls randomness. Lower temperatures mean lower randomness. Default: 0
-    - `top_p` (float, optional): Limits generation at each step to top `k` most likely tokens. Default: 0
-    - `max_tokens` (int, optional): The maximum number of tokens in the response. Default: 4096. Maximum allowed: 8192.
-    - `guardrails` (bool, optional): Whether to enable Cortex Guard to filter potentially unsafe responses. Default: False.
-    - `response_format` (str, optional): A JSON schema that the response should follow.
     """
-
-    temperature: Optional[float]
-    top_p: Optional[float]
-    max_tokens: Optional[int]
-    guardrails: Optional[bool]
-    response_format: Optional[str]
-
-    def __init__(
-        self,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        guardrails: Optional[bool] = None,
-        response_format: Optional[str] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
     @classmethod
     def get_config(cls):
@@ -60,7 +35,7 @@ class SnowflakeConfig(OpenAILikeChatConfig):
             "top_p",
             "response_format"
         ]
 
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -83,4 +58,160 @@ class SnowflakeConfig(OpenAILikeChatConfig):
         for param, value in non_default_params.items():
             if param in supported_openai_params:
                 optional_params[param] = value
         return optional_params
+
+    @staticmethod
+    def _convert_tool_response_to_message(
+        message: ChatCompletionAssistantMessage, json_mode: bool
+    ) -> ChatCompletionAssistantMessage:
+        """
+        If json_mode is true, convert the returned tool-call response to a content with a JSON string.
+
+        e.g. input:
+        {"role": "assistant", "tool_calls": [{"id": "call_5ms4", "type": "function", "function": {"name": "json_tool_call", "arguments": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}}]}
+
+        output:
+        {"role": "assistant", "content": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}
+        """
+        if not json_mode:
+            return message
+
+        _tool_calls = message.get("tool_calls")
+        if _tool_calls is None or len(_tool_calls) != 1:
+            return message
+
+        message["content"] = _tool_calls[0]["function"].get("arguments") or ""
+        message["tool_calls"] = None
+        return message
+
+    @staticmethod
+    def transform_response(
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        response_json = raw_response.json()
+
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        if json_mode:
+            for choice in response_json["choices"]:
+                message = SnowflakeConfig._convert_tool_response_to_message(
+                    choice.get("message"), json_mode
+                )
+                choice["message"] = message
+
+        returned_response = ModelResponse(**response_json)
+
+        returned_response.model = "snowflake/" + (returned_response.model or "")
+
+        if model is not None:
+            returned_response._hidden_params["model"] = model
+        return returned_response
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        messages: Optional[List[AllMessageValues]] = None,
+        optional_params: Optional[dict] = None,
+    ) -> dict:
+        """
+        Return headers to use for the Snowflake completion request.
+
+        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
+        Expected headers:
+        {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": "Bearer " + <JWT>,
+            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+        }
+        """
+        if api_key is None:
+            raise ValueError("Missing Snowflake JWT key")
+
+        headers.update(
+            {
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                "Authorization": "Bearer " + api_key,
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
+            }
+        )
+        return headers
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        # Check SNOWFLAKE_API_BASE before the constructed default; the default
+        # f-string is always truthy, so it must come last in the `or` chain.
+        api_base = (
+            api_base
+            or get_secret("SNOWFLAKE_API_BASE")
+            or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        If api_base is not provided, default to the Snowflake Cortex inference endpoint.
+        """
+        if not api_base:
+            api_base = f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+        return api_base
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        stream: bool = optional_params.pop("stream", None) or False
+        extra_body = optional_params.pop("extra_body", {})
+        return {
+            "model": model,
+            "messages": messages,
+            "stream": stream,
+            **optional_params,
+            **extra_body,
+        }
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: ModelResponse,
+        sync_stream: bool,
+    ):
+        return ModelResponseIterator(
+            streaming_response=streaming_response, sync_stream=sync_stream
+        )
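For reference, a sketch of what the new config produces, using placeholder values (`mistral-7b` is illustrative; the expected results mirror the code above):

config = SnowflakeConfig()

# Headers: validate_environment() requires a JWT and adds the KEYPAIR_JWT marker.
headers = config.validate_environment(
    headers={},
    model="mistral-7b",
    api_key="<key-pair JWT>",
)
# -> {"Content-Type": "application/json", "Accept": "application/json",
#     "Authorization": "Bearer <key-pair JWT>",
#     "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"}

# Request body: transform_request() pops `stream`, then merges the remaining params.
body = config.transform_request(
    model="mistral-7b",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={"temperature": 0.2},
    litellm_params={},
    headers=headers,
)
# -> {"model": "mistral-7b", "messages": [...], "stream": False, "temperature": 0.2}

# json_mode handling: a single json_tool_call response collapses to plain content.
msg = {
    "role": "assistant",
    "tool_calls": [{
        "id": "call_5ms4",
        "type": "function",
        "function": {"name": "json_tool_call",
                     "arguments": '{"key": "question", "value": "What is the capital of France?"}'},
    }],
}
out = SnowflakeConfig._convert_tool_response_to_message(msg, json_mode=True)
# out["content"] is the arguments JSON string; out["tool_calls"] is None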

litellm/main.py

@@ -146,7 +146,6 @@ from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.petals.completion import handler as petals_handler
 from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
-from .llms.snowflake.completion.handler import SnowflakeChatCompletion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.vertex_ai import vertex_ai_non_gemini
@@ -237,7 +236,6 @@ databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
 sagemaker_chat_completion = SagemakerChatHandler()
-snow_flake_chat_completion = SnowflakeChatCompletion()
 
 ####### COMPLETION ENDPOINTS ################
@@ -2977,27 +2975,37 @@ def completion(  # type: ignore  # noqa: PLR0915
             return response
         response = model_response
     elif custom_llm_provider == "snowflake" or model in litellm.snowflake_models:
-        api_base = (
-            api_base
-            or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
-            or get_secret("SNOWFLAKE_API_BASE")
-        )
-        response = snow_flake_chat_completion.completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            acompletion=acompletion,
-            custom_prompt_dict=litellm.custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
-            JWT=api_key,
-            logging_obj=logging,
-            headers=headers,
-        )
+        try:
+            # Keep this here; otherwise the httpx client closes and streaming is impossible.
+            client = HTTPHandler(timeout=timeout) if stream is False else None
+            response = base_llm_http_handler.completion(
+                model=model,
+                messages=messages,
+                headers=headers,
+                model_response=model_response,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                timeout=timeout,  # type: ignore
+                client=client,
+                custom_llm_provider=custom_llm_provider,
+                encoding=encoding,
+                stream=stream,
+            )
+        except Exception as e:
+            ## LOGGING - log the original exception returned
+            logging.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+                additional_args={"headers": headers},
+            )
+            raise e
     elif custom_llm_provider == "custom":
         url = litellm.api_base or api_base or ""
         if url is None or url == "":
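The inline comment in this hunk is the key constraint: a pre-created `HTTPHandler` is only passed for non-streaming calls, since the client would otherwise be closed before the stream is consumed. A streaming sketch under the same placeholder assumptions as the earlier example:

stream = litellm.completion(
    model="snowflake/mistral-7b",  # placeholder model id
    messages=[{"role": "user", "content": "Count to three"}],
    stream=True,  # stream=True makes the branch above pass client=None
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")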