Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

removed handler and refactored to deepseek/chat format

This commit is contained in: parent 88f165853d, commit 0834ffaae3
3 changed files with 204 additions and 128 deletions
litellm/llms/snowflake/completion/handler.py (file deleted)

@@ -1,63 +0,0 @@
-from litellm.llms.base import BaseLLM
-from typing import Any, List, Optional
-from typing import List, Dict, Callable, Optional, Any, cast, Union
-
-import litellm
-from litellm.utils import ModelResponse
-from litellm.types.llms.openai import AllMessageValues
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
-from ..common_utils import SnowflakeBase
-
-
-class SnowflakeChatCompletion(OpenAILikeChatHandler, SnowflakeBase):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def completion(
-        self,
-        model: str,
-        messages: List[Dict[str, Any]],
-        api_base: str,
-        custom_prompt_dict: dict,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        encoding,
-        JWT: str,
-        logging_obj,
-        optional_params: dict,
-        acompletion=None,
-        litellm_params=None,
-        logger_fn=None,
-        headers: Optional[dict] = None,
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
-    ) -> None:
-
-        messages = litellm.SnowflakeConfig()._transform_messages(
-            messages=cast(List[AllMessageValues], messages), model=model
-        )
-
-        headers = self.validate_environment(
-            headers,
-            JWT
-        )
-
-        return super().completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            custom_llm_provider="snowflake",
-            custom_prompt_dict=custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            encoding=encoding,
-            api_key=JWT,
-            logging_obj=logging_obj,
-            optional_params=optional_params,
-            acompletion=acompletion,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=headers,
-            client=client,
-            custom_endpoint=True,
-        )
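With the dedicated handler deleted, Snowflake chat requests route through the shared base_llm_http_handler (see the litellm/main.py hunk further down); the caller-facing API is unchanged. A minimal usage sketch, with a hypothetical Cortex model id and a placeholder JWT (neither comes from this commit):

    import litellm

    response = litellm.completion(
        model="snowflake/mistral-large2",  # hypothetical model id
        messages=[{"role": "user", "content": "Hello from Cortex"}],
        api_key="<JWT>",                   # Snowflake key-pair JWT (placeholder)
    )
    print(response.choices[0].message.content)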
litellm/llms/snowflake/chat/transformation.py

@@ -2,52 +2,27 @@
 Support for Snowflake REST API
 '''
+import httpx
-from typing import List, Optional, Union, Any
+from typing import List, Optional, Tuple, Any, TYPE_CHECKING

-import litellm
-from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
-from litellm.litellm_core_utils.prompt_templates.common_utils import (
-    convert_content_list_to_str,
-)
-from ...openai_like.chat.transformation import OpenAILikeChatConfig
+from litellm.utils import get_secret
+from litellm.types.utils import ModelResponse
+from litellm.types.llms.openai import ChatCompletionAssistantMessage
+from litellm.llms.databricks.streaming_utils import ModelResponseIterator

+from ...openai_like.chat.transformation import OpenAIGPTConfig

-class SnowflakeConfig(OpenAILikeChatConfig):
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class SnowflakeConfig(OpenAIGPTConfig):
     """
     source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex

     The class `SnowflakeConfig` provides configuration for Snowflake's REST API interface. Below are the parameters:

     - `temperature` (float, optional): A value between 0 and 1 that controls randomness. Lower temperatures mean lower randomness. Default: 0

     - `top_p` (float, optional): Limits generation at each step to top `k` most likely tokens. Default: 0

     - `max_tokens` (int, optional): The maximum number of tokens in the response. Default: 4096. Maximum allowed: 8192.

     - `guardrails` (bool, optional): Whether to enable Cortex Guard to filter potentially unsafe responses. Default: False.

     - `response_format` (str, optional): A JSON schema that the response should follow.
     """

     temperature: Optional[float]
     top_p: Optional[float]
     max_tokens: Optional[int]
     guardrails: Optional[bool]
     response_format: Optional[str]

     def __init__(
         self,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         max_tokens: Optional[int] = None,
         guardrails: Optional[bool] = None,
         response_format: Optional[str] = None,
     ) -> None:
         locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

     @classmethod
     def get_config(cls):
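The documented parameters ride through litellm's normal keyword interface. A hedged sketch (model id and values are illustrative, not taken from this commit):

    import litellm

    response = litellm.completion(
        model="snowflake/mistral-large2",   # placeholder model id
        messages=[{"role": "user", "content": "Summarize Snowflake Cortex."}],
        temperature=0.2,   # 0-1; lower means less random (default 0)
        top_p=0.9,
        max_tokens=512,    # response cap; Snowflake allows up to 8192
    )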
@@ -84,3 +59,159 @@ class SnowflakeConfig(OpenAILikeChatConfig):
             if param in supported_openai_params:
                 optional_params[param] = value
         return optional_params
+
+    def _convert_tool_response_to_message(
+        message: ChatCompletionAssistantMessage, json_mode: bool
+    ) -> ChatCompletionAssistantMessage:
+        """
+        if json_mode is true, convert the returned tool call response to a content with json str
+
+        e.g. input:
+
+        {"role": "assistant", "tool_calls": [{"id": "call_5ms4", "type": "function", "function": {"name": "json_tool_call", "arguments": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}}]}
+
+        output:
+
+        {"role": "assistant", "content": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}
+        """
+        if not json_mode:
+            return message
+
+        _tool_calls = message.get("tool_calls")
+
+        if _tool_calls is None or len(_tool_calls) != 1:
+            return message
+
+        message["content"] = _tool_calls[0]["function"].get("arguments") or ""
+        message["tool_calls"] = None
+
+        return message
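A quick illustration of that conversion, applied to a hand-built message dict rather than a live response:

    msg = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_5ms4",
            "type": "function",
            "function": {
                "name": "json_tool_call",
                "arguments": '{"key": "question", "value": "What is the capital of France?"}',
            },
        }],
    }
    converted = SnowflakeConfig._convert_tool_response_to_message(msg, json_mode=True)
    # converted["content"] is now the arguments JSON string and
    # converted["tool_calls"] is None; with json_mode=False the
    # message would be returned untouched.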
+
+    @staticmethod
+    def transform_response(
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        response_json = raw_response.json()
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        if json_mode:
+            for choice in response_json["choices"]:
+                message = SnowflakeConfig._convert_tool_response_to_message(
+                    choice.get("message"), json_mode
+                )
+                choice["message"] = message
+
+        returned_response = ModelResponse(**response_json)
+
+        returned_response.model = (
+            "snowflake/" + (returned_response.model or "")
+        )
+
+        if model is not None:
+            returned_response._hidden_params["model"] = model
+        return returned_response
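Note the provider prefix applied above: the raw model name from Cortex comes back namespaced. A tiny sketch with a fabricated value:

    from litellm.types.utils import ModelResponse

    resp = ModelResponse(model="mistral-large2")   # illustrative raw value
    resp.model = "snowflake/" + (resp.model or "")
    assert resp.model == "snowflake/mistral-large2"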
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_base: str = None,
+        api_key: Optional[str] = None,
+        messages: dict = None,
+        optional_params: dict = None,
+    ) -> dict:
+        """
+        Return headers to use for Snowflake completion request
+
+        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
+        Expected headers:
+        {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": "Bearer " + <JWT>,
+            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+        }
+        """
+
+        if api_key is None:
+            raise ValueError(
+                "Missing Snowflake JWT key"
+            )
+
+        headers.update(
+            {
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                "Authorization": "Bearer " + api_key,
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+            }
+        )
+        return headers
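What that yields for a caller, sketched with a placeholder JWT:

    headers = SnowflakeConfig().validate_environment(
        headers={},
        model="mistral-large2",   # placeholder
        api_key="<JWT>",
    )
    # headers["Authorization"] == "Bearer <JWT>"
    # headers["X-Snowflake-Authorization-Token-Type"] == "KEYPAIR_JWT"
    # Passing api_key=None instead raises ValueError("Missing Snowflake JWT key").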
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = (
+            api_base
+            or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+            or get_secret("SNOWFLAKE_API_BASE")
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+        return api_base, dynamic_api_key
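Resolution order, sketched with assumed environment values; note that the f-string branch is always non-empty, so the trailing SNOWFLAKE_API_BASE fallback is unreachable as written:

    import os

    os.environ["SNOWFLAKE_ACCOUNT_ID"] = "my-account"   # placeholder account locator
    os.environ["SNOWFLAKE_JWT"] = "<JWT>"               # placeholder key-pair JWT

    api_base, api_key = SnowflakeConfig()._get_openai_compatible_provider_info(
        api_base=None, api_key=None
    )
    # api_base == "https://my-account.snowflakecomputing.com/api/v2/cortex/inference:complete"
    # api_key == "<JWT>"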
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        If api_base is not provided, use the default DeepSeek /chat/completions endpoint.
+        """
+        if not api_base:
+            api_base = f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+
+        return api_base
+
+    def transform_request(
+        self,
+        model: str,
+        messages: dict,
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        stream: bool = optional_params.pop("stream", None) or False
+        extra_body = optional_params.pop("extra_body", {})
+        return {
+            "model": model,
+            "messages": messages,
+            "stream": stream,
+            **optional_params,
+            **extra_body,
+        }
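The request body this produces for a simple streaming call (a sketch; values are illustrative):

    body = SnowflakeConfig().transform_request(
        model="mistral-large2",
        messages=[{"role": "user", "content": "Hello"}],
        optional_params={"temperature": 0.2, "stream": True},
        litellm_params={},
        headers={},
    )
    # body == {
    #     "model": "mistral-large2",
    #     "messages": [{"role": "user", "content": "Hello"}],
    #     "stream": True,
    #     "temperature": 0.2,
    # }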
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: ModelResponse,
+        sync_stream: bool,
+    ):
+        return ModelResponseIterator(streaming_response=streaming_response, sync_stream=sync_stream)
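End to end, streaming then looks like any other litellm provider; a hedged sketch with a placeholder model id:

    import litellm

    for chunk in litellm.completion(
        model="snowflake/mistral-large2",   # placeholder model id
        messages=[{"role": "user", "content": "Write a haiku."}],
        stream=True,
    ):
        print(chunk.choices[0].delta.content or "", end="")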
litellm/main.py

@@ -146,7 +146,6 @@ from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.petals.completion import handler as petals_handler
 from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
-from .llms.snowflake.completion.handler import SnowflakeChatCompletion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.vertex_ai import vertex_ai_non_gemini
@@ -237,7 +236,6 @@ databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
 sagemaker_chat_completion = SagemakerChatHandler()
-snow_flake_chat_completion = SnowflakeChatCompletion()
 ####### COMPLETION ENDPOINTS ################
@@ -2977,27 +2975,37 @@ def completion( # type: ignore # noqa: PLR0915
             return response
         response = model_response
     elif custom_llm_provider == "snowflake" or model in litellm.snowflake_models:
         api_base = (
             api_base
             or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
             or get_secret("SNOWFLAKE_API_BASE")
         )
-        response = snow_flake_chat_completion.completion(
+        try:
+            client = HTTPHandler(timeout=timeout) if stream is False else None  # Keep this here, otherwise, the httpx.client closes and streaming is impossible
+            response = base_llm_http_handler.completion(
                 model=model,
                 messages=messages,
+                headers=headers,
+                model_response=model_response,
+                api_key=api_key,
                 api_base=api_base,
+                acompletion=acompletion,
-                custom_prompt_dict=litellm.custom_prompt_dict,
-                model_response=model_response,
-                print_verbose=print_verbose,
+                logging_obj=logging,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
-                logger_fn=logger_fn,
+                timeout=timeout,  # type: ignore
+                client=client,
+                custom_llm_provider=custom_llm_provider,
                 encoding=encoding,
-                JWT=api_key,
-                logging_obj=logging,
-                headers=headers,
+                stream=stream,
             )
+
+        except Exception as e:
+            ## LOGGING - log the original exception returned
+            logging.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+                additional_args={"headers": headers},
+            )
+            raise e

     elif custom_llm_provider == "custom":
         url = litellm.api_base or api_base or ""
         if url is None or url == "":
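The stream-dependent client choice above is the load-bearing detail of the new branch; restated as a sketch (pick_client is a hypothetical helper name, not part of this commit):

    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    def pick_client(stream: bool, timeout: float):
        # Non-streaming: hand the shared handler a client bound to the caller's
        # timeout. Streaming: return None so the handler owns the client and the
        # httpx connection is not closed before the stream is consumed.
        return HTTPHandler(timeout=timeout) if stream is False else None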