mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
[FEAT] Added snowflake completion provider
This commit is contained in:
parent
842d8dec09
commit
fd090c8043
8 changed files with 288 additions and 0 deletions
|
@ -180,6 +180,7 @@ cloudflare_api_key: Optional[str] = None
|
||||||
baseten_key: Optional[str] = None
|
baseten_key: Optional[str] = None
|
||||||
aleph_alpha_key: Optional[str] = None
|
aleph_alpha_key: Optional[str] = None
|
||||||
nlp_cloud_key: Optional[str] = None
|
nlp_cloud_key: Optional[str] = None
|
||||||
|
snowflake_key: Optional[str] = None
|
||||||
common_cloud_provider_auth_params: dict = {
|
common_cloud_provider_auth_params: dict = {
|
||||||
"params": ["project", "region_name", "token"],
|
"params": ["project", "region_name", "token"],
|
||||||
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
|
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
|
||||||
|
@ -414,6 +415,7 @@ cerebras_models: List = []
|
||||||
galadriel_models: List = []
|
galadriel_models: List = []
|
||||||
sambanova_models: List = []
|
sambanova_models: List = []
|
||||||
assemblyai_models: List = []
|
assemblyai_models: List = []
|
||||||
|
snowflake_models: List = []
|
||||||
|
|
||||||
|
|
||||||
def is_bedrock_pricing_only_model(key: str) -> bool:
|
def is_bedrock_pricing_only_model(key: str) -> bool:
|
||||||
|
@ -567,6 +569,8 @@ def add_known_models():
|
||||||
assemblyai_models.append(key)
|
assemblyai_models.append(key)
|
||||||
elif value.get("litellm_provider") == "jina_ai":
|
elif value.get("litellm_provider") == "jina_ai":
|
||||||
jina_ai_models.append(key)
|
jina_ai_models.append(key)
|
||||||
|
elif value.get("litellm_provider") == "snowflake":
|
||||||
|
snowflake_models.append(key)
|
||||||
|
|
||||||
|
|
||||||
add_known_models()
|
add_known_models()
|
||||||
|
@ -596,6 +600,30 @@ ollama_models = ["llama2"]
|
||||||
|
|
||||||
maritalk_models = ["maritalk"]
|
maritalk_models = ["maritalk"]
|
||||||
|
|
||||||
|
# Probably shouldn't hard code this, change later
|
||||||
|
snowflake_models = [
|
||||||
|
"snowflake/deepseek-r1",
|
||||||
|
"snowflake/claude-3-5-sonnet",
|
||||||
|
"snowflake/llama3.2-1b",
|
||||||
|
"snowflake/llama3.2-3b",
|
||||||
|
"snowflake/llama3.1-8b",
|
||||||
|
"snowflake/llama3.1-70b",
|
||||||
|
"snowflake/llama3.3-70b",
|
||||||
|
"snowflake/snowflake-llama-3.3-70b",
|
||||||
|
"snowflake/llama3.1-405b",
|
||||||
|
"snowflake/snowflake-llama-3.1-405b",
|
||||||
|
"snowflake/snowflake-arctic",
|
||||||
|
"snowflake/reka-core",
|
||||||
|
"snowflake/reka-flash",
|
||||||
|
"snowflake/mistral-large2",
|
||||||
|
"snowflake/mixtral-8x7b",
|
||||||
|
"snowflake/mistral-7b",
|
||||||
|
"snowflake/jamba-instruct",
|
||||||
|
"snowflake/jamba-1.5-mini",
|
||||||
|
"snowflake/jamba-1.5-large",
|
||||||
|
"snowflake/gemma-7b"
|
||||||
|
]
|
||||||
|
|
||||||
model_list = (
|
model_list = (
|
||||||
open_ai_chat_completion_models
|
open_ai_chat_completion_models
|
||||||
+ open_ai_text_completion_models
|
+ open_ai_text_completion_models
|
||||||
|
@ -640,6 +668,7 @@ model_list = (
|
||||||
+ azure_text_models
|
+ azure_text_models
|
||||||
+ assemblyai_models
|
+ assemblyai_models
|
||||||
+ jina_ai_models
|
+ jina_ai_models
|
||||||
|
+ snowflake_models
|
||||||
)
|
)
|
||||||
|
|
||||||
model_list_set = set(model_list)
|
model_list_set = set(model_list)
|
||||||
|
@ -695,6 +724,7 @@ models_by_provider: dict = {
|
||||||
"sambanova": sambanova_models,
|
"sambanova": sambanova_models,
|
||||||
"assemblyai": assemblyai_models,
|
"assemblyai": assemblyai_models,
|
||||||
"jina_ai": jina_ai_models,
|
"jina_ai": jina_ai_models,
|
||||||
|
"snowflake": snowflake_models,
|
||||||
}
|
}
|
||||||
|
|
||||||
# mapping for those models which have larger equivalents
|
# mapping for those models which have larger equivalents
|
||||||
|
@ -928,6 +958,8 @@ from .llms.openai.chat.o_series_transformation import (
|
||||||
OpenAIOSeriesConfig,
|
OpenAIOSeriesConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .llms.snowflake.completion.transformation import SnowflakeConfig
|
||||||
|
|
||||||
openaiOSeriesConfig = OpenAIOSeriesConfig()
|
openaiOSeriesConfig = OpenAIOSeriesConfig()
|
||||||
from .llms.openai.chat.gpt_transformation import (
|
from .llms.openai.chat.gpt_transformation import (
|
||||||
OpenAIGPTConfig,
|
OpenAIGPTConfig,
|
||||||
|
|
|
@ -571,6 +571,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
|
||||||
or "https://api.galadriel.com/v1"
|
or "https://api.galadriel.com/v1"
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
|
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
|
||||||
|
elif custom_llm_provider == "snowflake":
|
||||||
|
api_base = (
|
||||||
|
api_base
|
||||||
|
or get_secret("SNOWFLAKE_API_BASE")
|
||||||
|
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
|
||||||
|
) # type: ignore
|
||||||
|
dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT") # Snowflake doesn't use API keys so this will have to change. Support of OAuth and JWT
|
||||||
|
|
||||||
if api_base is not None and not isinstance(api_base, str):
|
if api_base is not None and not isinstance(api_base, str):
|
||||||
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
||||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||||
|
|
40
litellm/llms/snowflake/common_utils.py
Normal file
40
litellm/llms/snowflake/common_utils.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
import httpx
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SnowflakeBase:
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: dict,
|
||||||
|
JWT: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Return headers to use for Snowflake completion request
|
||||||
|
|
||||||
|
Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
|
||||||
|
Expected headers:
|
||||||
|
{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Authorization": "Bearer " + <JWT>,
|
||||||
|
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
if JWT is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing Snowflake JWT key"
|
||||||
|
)
|
||||||
|
|
||||||
|
headers.update(
|
||||||
|
{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Authorization": "Bearer " + JWT,
|
||||||
|
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
63
litellm/llms/snowflake/completion/handler.py
Normal file
63
litellm/llms/snowflake/completion/handler.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
from litellm.llms.base import BaseLLM
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
from typing import List, Dict, Callable, Optional, Any,cast
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.utils import ModelResponse
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
|
||||||
|
from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
|
||||||
|
from ..common_utils import SnowflakeBase
|
||||||
|
|
||||||
|
class SnowflakeChatCompletion(OpenAILikeChatHandler,SnowflakeBase):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
api_base: str,
|
||||||
|
acompletion: str,
|
||||||
|
custom_prompt_dict: Dict[str, Any],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding: Any,
|
||||||
|
JWT: str,
|
||||||
|
logging_obj: Any,
|
||||||
|
optional_params: Optional[Dict[str, Any]] = None,
|
||||||
|
litellm_params: Optional[Dict[str, Any]] = None,
|
||||||
|
logger_fn: Optional[Callable] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
client: Optional[Any] = None,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
messages = litellm.SnowflakeConfig()._transform_messages(
|
||||||
|
messages=cast(List[AllMessageValues], messages), model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = self.validate_environment(
|
||||||
|
headers,
|
||||||
|
JWT
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
api_base=api_base,
|
||||||
|
custom_llm_provider= "snowflake",
|
||||||
|
custom_prompt_dict=custom_prompt_dict,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
api_key=JWT,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
acompletion=acompletion,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
client=client,
|
||||||
|
custom_endpoint=True,
|
||||||
|
)
|
110
litellm/llms/snowflake/completion/transformation.py
Normal file
110
litellm/llms/snowflake/completion/transformation.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
'''
|
||||||
|
Support for Snowflake REST API
|
||||||
|
'''
|
||||||
|
import httpx
|
||||||
|
from typing import List, Optional, Union, Any
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
|
||||||
|
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||||
|
convert_content_list_to_str,
|
||||||
|
)
|
||||||
|
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SnowflakeConfig(OpenAILikeChatConfig):
|
||||||
|
"""
|
||||||
|
source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
|
||||||
|
|
||||||
|
The class `SnowflakeConfig` provides configuration for Snowflake's REST API interface. Below are the parameters:
|
||||||
|
|
||||||
|
- `temperature` (float, optional): A value between 0 and 1 that controls randomness. Lower temperatures mean lower randomness. Default: 0
|
||||||
|
|
||||||
|
- `top_p` (float, optional): Limits generation at each step to top `k` most likely tokens. Default: 0
|
||||||
|
|
||||||
|
- `max_tokens `(int, optional): The maximum number of tokens in the response. Default: 4096. Maximum allowed: 8192.
|
||||||
|
|
||||||
|
- `guardrails` (bool, optional): Whether to enable Cortex Guard to filter potentially unsafe responses. Default: False.
|
||||||
|
|
||||||
|
- `response_format` (str, optional): A JSON schema that the response should follow
|
||||||
|
"""
|
||||||
|
temperature: Optional[float]
|
||||||
|
top_p: Optional[float]
|
||||||
|
max_tokens: Optional[int]
|
||||||
|
guardrails: Optional[bool]
|
||||||
|
response_format: Optional[str]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
guardrails: Optional[bool] = None,
|
||||||
|
response_format: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals().copy()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return super().get_config()
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List:
|
||||||
|
return [
|
||||||
|
"temperature",
|
||||||
|
"max_tokens",
|
||||||
|
"top_p",
|
||||||
|
"response_format"
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in API call
|
||||||
|
|
||||||
|
Args:
|
||||||
|
non_default_params (dict): Non-default parameters to filter.
|
||||||
|
optional_params (dict): Optional parameters to update.
|
||||||
|
model (str): Model name for parameter support check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Updated optional_params with supported non-default parameters.
|
||||||
|
"""
|
||||||
|
supported_openai_params = self.get_supported_openai_params(model)
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param in supported_openai_params:
|
||||||
|
optional_params[param] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
# def _transform_messages(
|
||||||
|
# self,
|
||||||
|
# model: str,
|
||||||
|
# messages: List[AllMessageValues],
|
||||||
|
# optional_params: dict,
|
||||||
|
# litellm_params: dict,
|
||||||
|
# headers: dict,
|
||||||
|
# ) -> dict:
|
||||||
|
# config = litellm.SnowflakeConfig.get_config()
|
||||||
|
# for k, v in config.items():
|
||||||
|
# if (
|
||||||
|
# k not in optional_params
|
||||||
|
# ):
|
||||||
|
# optional_params[k] = v
|
||||||
|
|
||||||
|
# text = " ".join(convert_content_list_to_str(message) for message in messages)
|
||||||
|
|
||||||
|
# data = {
|
||||||
|
# "text": text,
|
||||||
|
# **optional_params,
|
||||||
|
# }
|
||||||
|
|
||||||
|
# return data
|
|
@ -146,6 +146,7 @@ from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
|
||||||
from .llms.petals.completion import handler as petals_handler
|
from .llms.petals.completion import handler as petals_handler
|
||||||
from .llms.predibase.chat.handler import PredibaseChatCompletion
|
from .llms.predibase.chat.handler import PredibaseChatCompletion
|
||||||
from .llms.replicate.chat.handler import completion as replicate_chat_completion
|
from .llms.replicate.chat.handler import completion as replicate_chat_completion
|
||||||
|
from .llms.snowflake.completion.handler import SnowflakeChatCompletion
|
||||||
from .llms.sagemaker.chat.handler import SagemakerChatHandler
|
from .llms.sagemaker.chat.handler import SagemakerChatHandler
|
||||||
from .llms.sagemaker.completion.handler import SagemakerLLM
|
from .llms.sagemaker.completion.handler import SagemakerLLM
|
||||||
from .llms.vertex_ai import vertex_ai_non_gemini
|
from .llms.vertex_ai import vertex_ai_non_gemini
|
||||||
|
@ -236,6 +237,7 @@ databricks_embedding = DatabricksEmbeddingHandler()
|
||||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||||
base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
|
base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
|
||||||
sagemaker_chat_completion = SagemakerChatHandler()
|
sagemaker_chat_completion = SagemakerChatHandler()
|
||||||
|
snow_flake_chat_completion = SnowflakeChatCompletion()
|
||||||
####### COMPLETION ENDPOINTS ################
|
####### COMPLETION ENDPOINTS ################
|
||||||
|
|
||||||
|
|
||||||
|
@ -2974,6 +2976,28 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
response = model_response
|
response = model_response
|
||||||
|
elif custom_llm_provider == "snowflake" or model in litellm.snowflake_models:
|
||||||
|
api_base = (
|
||||||
|
api_base
|
||||||
|
or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
|
||||||
|
or get_secret("SNOWFLAKE_API_BASE")
|
||||||
|
)
|
||||||
|
response = snow_flake_chat_completion.completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
api_base=api_base,
|
||||||
|
acompletion=acompletion,
|
||||||
|
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
encoding=encoding,
|
||||||
|
JWT=api_key,
|
||||||
|
logging_obj=logging,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
elif custom_llm_provider == "custom":
|
elif custom_llm_provider == "custom":
|
||||||
url = litellm.api_base or api_base or ""
|
url = litellm.api_base or api_base or ""
|
||||||
if url is None or url == "":
|
if url is None or url == "":
|
||||||
|
@ -3032,6 +3056,7 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
model_response.created = int(time.time())
|
model_response.created = int(time.time())
|
||||||
model_response.model = model
|
model_response.model = model
|
||||||
response = model_response
|
response = model_response
|
||||||
|
|
||||||
elif (
|
elif (
|
||||||
custom_llm_provider in litellm._custom_providers
|
custom_llm_provider in litellm._custom_providers
|
||||||
): # Assume custom LLM provider
|
): # Assume custom LLM provider
|
||||||
|
|
|
@ -1911,6 +1911,7 @@ class LlmProviders(str, Enum):
|
||||||
HUMANLOOP = "humanloop"
|
HUMANLOOP = "humanloop"
|
||||||
TOPAZ = "topaz"
|
TOPAZ = "topaz"
|
||||||
ASSEMBLYAI = "assemblyai"
|
ASSEMBLYAI = "assemblyai"
|
||||||
|
SNOWFLAKE = "snowflake"
|
||||||
|
|
||||||
|
|
||||||
# Create a set of all provider values for quick lookup
|
# Create a set of all provider values for quick lookup
|
||||||
|
|
9
snowflake_testing.py
Normal file
9
snowflake_testing.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "EBSRFJH-BI29448"
|
||||||
|
os.environ["SNOWFLAKE_JWT"] = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJFQlNSRkpILkJJMjk0NDguU0hBMjU2OjZXdVlwazZPSTBUNHhMb0VGaVVWRWN0R3V2cWsrOC9oVmJibzcwcUIrOFk9Iiwic3ViIjoiRUJTUkZKSC5CSTI5NDQ4IiwiaWF0IjoxNzQwOTc5NzEwLCJleHAiOjE3NDEwNjYxMTB9.XpI50hT1O6SbnNCeAfz2TFke_V4y3fBoaNaS230lg2eTTzhfVKoda0azCQDeYf8BTLSJjAjtjPuXbEgnoB1J0keQW9H8hJUItvRhfYnqN3ci_Ln4IoiLvwYM2BneoQ7pZdYrC3nxz0PBRxuMpkNTSp4FFoFwtbPhvzgHH5TMBJA3Kyt7Usr1RpNxJIIcR43M9wjCpovj_9wJlG2ry1dpqrB_aZTssnynLFnE9533V8WgLbtiX-balobjpZcPNUMZB_fv-aHGUT6wq5SOP2G0opbVBGq_NpW5R1ZF-oYVIXiaKxzfN_PK9RhbjHVVxZU-As4llKKlAYmC8ArFMOVsrA"
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Write me a poem about the blue sky"}]
|
||||||
|
|
||||||
|
completion(model="snowflake/mistral-7b", messages=messages)
|
Loading…
Add table
Add a link
Reference in a new issue