From fd090c80435238d14325040f5bf2e3ac7f8dcb91 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Mon, 3 Mar 2025 01:20:00 -0500 Subject: [PATCH 01/21] [FEAT] Added snowflake completion provider --- litellm/__init__.py | 32 +++++ .../get_llm_provider_logic.py | 8 ++ litellm/llms/snowflake/common_utils.py | 40 +++++++ litellm/llms/snowflake/completion/handler.py | 63 ++++++++++ .../snowflake/completion/transformation.py | 110 ++++++++++++++++++ litellm/main.py | 25 ++++ litellm/types/utils.py | 1 + snowflake_testing.py | 9 ++ 8 files changed, 288 insertions(+) create mode 100644 litellm/llms/snowflake/common_utils.py create mode 100644 litellm/llms/snowflake/completion/handler.py create mode 100644 litellm/llms/snowflake/completion/transformation.py create mode 100644 snowflake_testing.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 60b8cf81a0..0086e20899 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -180,6 +180,7 @@ cloudflare_api_key: Optional[str] = None baseten_key: Optional[str] = None aleph_alpha_key: Optional[str] = None nlp_cloud_key: Optional[str] = None +snowflake_key: Optional[str] = None common_cloud_provider_auth_params: dict = { "params": ["project", "region_name", "token"], "providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"], @@ -414,6 +415,7 @@ cerebras_models: List = [] galadriel_models: List = [] sambanova_models: List = [] assemblyai_models: List = [] +snowflake_models: List = [] def is_bedrock_pricing_only_model(key: str) -> bool: @@ -567,6 +569,8 @@ def add_known_models(): assemblyai_models.append(key) elif value.get("litellm_provider") == "jina_ai": jina_ai_models.append(key) + elif value.get("litellm_provider") == "snowflake": + snowflake_models.append(key) add_known_models() @@ -596,6 +600,30 @@ ollama_models = ["llama2"] maritalk_models = ["maritalk"] +# Probably shouldn't hard code this, change later +snowflake_models = [ + "snowflake/deepseek-r1", + "snowflake/claude-3-5-sonnet", + "snowflake/llama3.2-1b", + "snowflake/llama3.2-3b", + "snowflake/llama3.1-8b", + "snowflake/llama3.1-70b", + "snowflake/llama3.3-70b", + "snowflake/snowflake-llama-3.3-70b", + "snowflake/llama3.1-405b", + "snowflake/snowflake-llama-3.1-405b", + "snowflake/snowflake-arctic", + "snowflake/reka-core", + "snowflake/reka-flash", + "snowflake/mistral-large2", + "snowflake/mixtral-8x7b", + "snowflake/mistral-7b", + "snowflake/jamba-instruct", + "snowflake/jamba-1.5-mini", + "snowflake/jamba-1.5-large", + "snowflake/gemma-7b" +] + model_list = ( open_ai_chat_completion_models + open_ai_text_completion_models @@ -640,6 +668,7 @@ model_list = ( + azure_text_models + assemblyai_models + jina_ai_models + + snowflake_models ) model_list_set = set(model_list) @@ -695,6 +724,7 @@ models_by_provider: dict = { "sambanova": sambanova_models, "assemblyai": assemblyai_models, "jina_ai": jina_ai_models, + "snowflake": snowflake_models, } # mapping for those models which have larger equivalents @@ -928,6 +958,8 @@ from .llms.openai.chat.o_series_transformation import ( OpenAIOSeriesConfig, ) +from .llms.snowflake.completion.transformation import SnowflakeConfig + openaiOSeriesConfig = OpenAIOSeriesConfig() from .llms.openai.chat.gpt_transformation import ( OpenAIGPTConfig, diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index a64e7dd700..5e9489d427 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -571,6 
+571,14 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
             or "https://api.galadriel.com/v1"
         )  # type: ignore
         dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
+    elif custom_llm_provider == "snowflake":
+        api_base = (
+            api_base
+            or get_secret("SNOWFLAKE_API_BASE")
+            or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")  # Snowflake doesn't use API keys so this will have to change. Support of OAuth and JWT
+
     if api_base is not None and not isinstance(api_base, str):
         raise Exception("api base needs to be a string. api_base={}".format(api_base))
     if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
diff --git a/litellm/llms/snowflake/common_utils.py b/litellm/llms/snowflake/common_utils.py
new file mode 100644
index 0000000000..5ceeee7403
--- /dev/null
+++ b/litellm/llms/snowflake/common_utils.py
@@ -0,0 +1,40 @@
+import httpx
+from typing import List, Optional
+
+
+
+class SnowflakeBase:
+    def validate_environment(
+        self,
+        headers: dict,
+        JWT: Optional[str] = None,
+    ) -> dict:
+        """
+        Return headers to use for Snowflake completion request
+
+        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
+        Expected headers:
+        {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": "Bearer " + <JWT>,
+            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+        }
+        """
+
+        if JWT is None:
+            raise ValueError(
+                "Missing Snowflake JWT key"
+            )
+
+        headers.update(
+            {
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                "Authorization": "Bearer " + JWT,
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+            }
+        )
+        return headers
+
+
\ No newline at end of file
diff --git a/litellm/llms/snowflake/completion/handler.py b/litellm/llms/snowflake/completion/handler.py
new file mode 100644
index 0000000000..a330ca83c9
--- /dev/null
+++ b/litellm/llms/snowflake/completion/handler.py
@@ -0,0 +1,63 @@
+from litellm.llms.base import BaseLLM
+from typing import Any, List, Optional
+from typing import List, Dict, Callable, Optional, Any,cast
+
+import litellm
+from litellm.utils import ModelResponse
+from litellm.types.llms.openai import AllMessageValues
+
+from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
+from ..common_utils import SnowflakeBase
+
+class SnowflakeChatCompletion(OpenAILikeChatHandler,SnowflakeBase):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def completion(
+        self,
+        model: str,
+        messages: List[Dict[str, Any]],
+        api_base: str,
+        acompletion: str,
+        custom_prompt_dict: Dict[str, Any],
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding: Any,
+        JWT: str,
+        logging_obj: Any,
+        optional_params: Optional[Dict[str, Any]] = None,
+        litellm_params: Optional[Dict[str, Any]] = None,
+        logger_fn: Optional[Callable] = None,
+        headers: Optional[Dict[str, str]] = None,
+        client: Optional[Any] = None,
+    ) -> None:
+
+        messages = litellm.SnowflakeConfig()._transform_messages(
+            messages=cast(List[AllMessageValues], messages), model=model
+        )
+
+        headers = self.validate_environment(
+            headers,
+            JWT
+        )
+
+        return super().completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_llm_provider= "snowflake",
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=JWT,
+ logging_obj=logging_obj, + optional_params=optional_params, + acompletion=acompletion, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + client=client, + custom_endpoint=True, + ) diff --git a/litellm/llms/snowflake/completion/transformation.py b/litellm/llms/snowflake/completion/transformation.py new file mode 100644 index 0000000000..1208ce3586 --- /dev/null +++ b/litellm/llms/snowflake/completion/transformation.py @@ -0,0 +1,110 @@ +''' +Support for Snowflake REST API +''' +import httpx +from typing import List, Optional, Union, Any + +import litellm +from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + convert_content_list_to_str, +) +from ...openai_like.chat.transformation import OpenAILikeChatConfig + + +class SnowflakeConfig(OpenAILikeChatConfig): + """ + source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex + + The class `SnowflakeConfig` provides configuration for Snowflake's REST API interface. Below are the parameters: + + - `temperature` (float, optional): A value between 0 and 1 that controls randomness. Lower temperatures mean lower randomness. Default: 0 + + - `top_p` (float, optional): Limits generation at each step to top `k` most likely tokens. Default: 0 + + - `max_tokens `(int, optional): The maximum number of tokens in the response. Default: 4096. Maximum allowed: 8192. + + - `guardrails` (bool, optional): Whether to enable Cortex Guard to filter potentially unsafe responses. Default: False. + + - `response_format` (str, optional): A JSON schema that the response should follow + """ + temperature: Optional[float] + top_p: Optional[float] + max_tokens: Optional[int] + guardrails: Optional[bool] + response_format: Optional[str] + + def __init__( + self, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + guardrails: Optional[bool] = None, + response_format: Optional[str] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def get_supported_openai_params(self, model: str) -> List: + return [ + "temperature", + "max_tokens", + "top_p", + "response_format" + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + """ + If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in API call + + Args: + non_default_params (dict): Non-default parameters to filter. + optional_params (dict): Optional parameters to update. + model (str): Model name for parameter support check. + + Returns: + dict: Updated optional_params with supported non-default parameters. 
+ """ + supported_openai_params = self.get_supported_openai_params(model) + for param, value in non_default_params.items(): + if param in supported_openai_params: + optional_params[param] = value + return optional_params + + # def _transform_messages( + # self, + # model: str, + # messages: List[AllMessageValues], + # optional_params: dict, + # litellm_params: dict, + # headers: dict, + # ) -> dict: + # config = litellm.SnowflakeConfig.get_config() + # for k, v in config.items(): + # if ( + # k not in optional_params + # ): + # optional_params[k] = v + + # text = " ".join(convert_content_list_to_str(message) for message in messages) + + # data = { + # "text": text, + # **optional_params, + # } + + # return data \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index 57bcda61fd..e4abbd8458 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -146,6 +146,7 @@ from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler from .llms.petals.completion import handler as petals_handler from .llms.predibase.chat.handler import PredibaseChatCompletion from .llms.replicate.chat.handler import completion as replicate_chat_completion +from .llms.snowflake.completion.handler import SnowflakeChatCompletion from .llms.sagemaker.chat.handler import SagemakerChatHandler from .llms.sagemaker.completion.handler import SagemakerLLM from .llms.vertex_ai import vertex_ai_non_gemini @@ -236,6 +237,7 @@ databricks_embedding = DatabricksEmbeddingHandler() base_llm_http_handler = BaseLLMHTTPHandler() base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler() sagemaker_chat_completion = SagemakerChatHandler() +snow_flake_chat_completion = SnowflakeChatCompletion() ####### COMPLETION ENDPOINTS ################ @@ -2974,6 +2976,28 @@ def completion( # type: ignore # noqa: PLR0915 ) return response response = model_response + elif custom_llm_provider == "snowflake" or model in litellm.snowflake_models: + api_base = ( + api_base + or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete""" + or get_secret("SNOWFLAKE_API_BASE") + ) + response = snow_flake_chat_completion.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + JWT=api_key, + logging_obj=logging, + headers=headers, + ) elif custom_llm_provider == "custom": url = litellm.api_base or api_base or "" if url is None or url == "": @@ -3032,6 +3056,7 @@ def completion( # type: ignore # noqa: PLR0915 model_response.created = int(time.time()) model_response.model = model response = model_response + elif ( custom_llm_provider in litellm._custom_providers ): # Assume custom LLM provider diff --git a/litellm/types/utils.py b/litellm/types/utils.py index dcaf5f35d1..b4de66d2fa 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1911,6 +1911,7 @@ class LlmProviders(str, Enum): HUMANLOOP = "humanloop" TOPAZ = "topaz" ASSEMBLYAI = "assemblyai" + SNOWFLAKE = "snowflake" # Create a set of all provider values for quick lookup diff --git a/snowflake_testing.py b/snowflake_testing.py new file mode 100644 index 0000000000..0ed59cc43c --- /dev/null +++ b/snowflake_testing.py @@ -0,0 +1,9 @@ +import os +from litellm import completion + +os.environ["SNOWFLAKE_ACCOUNT_ID"] = "EBSRFJH-BI29448" +os.environ["SNOWFLAKE_JWT"] = 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJFQlNSRkpILkJJMjk0NDguU0hBMjU2OjZXdVlwazZPSTBUNHhMb0VGaVVWRWN0R3V2cWsrOC9oVmJibzcwcUIrOFk9Iiwic3ViIjoiRUJTUkZKSC5CSTI5NDQ4IiwiaWF0IjoxNzQwOTc5NzEwLCJleHAiOjE3NDEwNjYxMTB9.XpI50hT1O6SbnNCeAfz2TFke_V4y3fBoaNaS230lg2eTTzhfVKoda0azCQDeYf8BTLSJjAjtjPuXbEgnoB1J0keQW9H8hJUItvRhfYnqN3ci_Ln4IoiLvwYM2BneoQ7pZdYrC3nxz0PBRxuMpkNTSp4FFoFwtbPhvzgHH5TMBJA3Kyt7Usr1RpNxJIIcR43M9wjCpovj_9wJlG2ry1dpqrB_aZTssnynLFnE9533V8WgLbtiX-balobjpZcPNUMZB_fv-aHGUT6wq5SOP2G0opbVBGq_NpW5R1ZF-oYVIXiaKxzfN_PK9RhbjHVVxZU-As4llKKlAYmC8ArFMOVsrA" + +messages = [{"role": "user", "content": "Write me a poem about the blue sky"}] + +completion(model="snowflake/mistral-7b", messages=messages) \ No newline at end of file From b87704cc34a1c857d2e017001ea01d6b349924bc Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Mon, 3 Mar 2025 01:34:44 -0500 Subject: [PATCH 02/21] [CHORE] Fixed some style issues and leaks --- litellm/llms/snowflake/common_utils.py | 2 -- litellm/llms/snowflake/completion/handler.py | 2 +- snowflake_testing.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/litellm/llms/snowflake/common_utils.py b/litellm/llms/snowflake/common_utils.py index 5ceeee7403..ab8d37bcf1 100644 --- a/litellm/llms/snowflake/common_utils.py +++ b/litellm/llms/snowflake/common_utils.py @@ -1,8 +1,6 @@ import httpx from typing import List, Optional - - class SnowflakeBase: def validate_environment( self, diff --git a/litellm/llms/snowflake/completion/handler.py b/litellm/llms/snowflake/completion/handler.py index a330ca83c9..039ae78351 100644 --- a/litellm/llms/snowflake/completion/handler.py +++ b/litellm/llms/snowflake/completion/handler.py @@ -19,7 +19,7 @@ class SnowflakeChatCompletion(OpenAILikeChatHandler,SnowflakeBase): model: str, messages: List[Dict[str, Any]], api_base: str, - acompletion: str, + acompletion: bool, custom_prompt_dict: Dict[str, Any], model_response: ModelResponse, print_verbose: Callable, diff --git a/snowflake_testing.py b/snowflake_testing.py index 0ed59cc43c..7e779b5803 100644 --- a/snowflake_testing.py +++ b/snowflake_testing.py @@ -1,8 +1,8 @@ import os from litellm import completion -os.environ["SNOWFLAKE_ACCOUNT_ID"] = "EBSRFJH-BI29448" -os.environ["SNOWFLAKE_JWT"] = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJFQlNSRkpILkJJMjk0NDguU0hBMjU2OjZXdVlwazZPSTBUNHhMb0VGaVVWRWN0R3V2cWsrOC9oVmJibzcwcUIrOFk9Iiwic3ViIjoiRUJTUkZKSC5CSTI5NDQ4IiwiaWF0IjoxNzQwOTc5NzEwLCJleHAiOjE3NDEwNjYxMTB9.XpI50hT1O6SbnNCeAfz2TFke_V4y3fBoaNaS230lg2eTTzhfVKoda0azCQDeYf8BTLSJjAjtjPuXbEgnoB1J0keQW9H8hJUItvRhfYnqN3ci_Ln4IoiLvwYM2BneoQ7pZdYrC3nxz0PBRxuMpkNTSp4FFoFwtbPhvzgHH5TMBJA3Kyt7Usr1RpNxJIIcR43M9wjCpovj_9wJlG2ry1dpqrB_aZTssnynLFnE9533V8WgLbtiX-balobjpZcPNUMZB_fv-aHGUT6wq5SOP2G0opbVBGq_NpW5R1ZF-oYVIXiaKxzfN_PK9RhbjHVVxZU-As4llKKlAYmC8ArFMOVsrA" +os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT" +os.environ["SNOWFLAKE_JWT"] = "YOUR JWT" messages = [{"role": "user", "content": "Write me a poem about the blue sky"}] From 61ee71745a8fe02d69b36df7b5b49bc34effccd4 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Mon, 3 Mar 2025 01:42:48 -0500 Subject: [PATCH 03/21] [CHORE] Added proper typing --- litellm/llms/snowflake/completion/handler.py | 22 ++++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/litellm/llms/snowflake/completion/handler.py b/litellm/llms/snowflake/completion/handler.py index 039ae78351..85ec676606 100644 --- a/litellm/llms/snowflake/completion/handler.py +++ b/litellm/llms/snowflake/completion/handler.py @@ -1,11 +1,11 @@ 
from litellm.llms.base import BaseLLM from typing import Any, List, Optional -from typing import List, Dict, Callable, Optional, Any,cast +from typing import List, Dict, Callable, Optional, Any, cast, Union import litellm from litellm.utils import ModelResponse from litellm.types.llms.openai import AllMessageValues - +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler from ..common_utils import SnowflakeBase @@ -19,18 +19,18 @@ class SnowflakeChatCompletion(OpenAILikeChatHandler,SnowflakeBase): model: str, messages: List[Dict[str, Any]], api_base: str, - acompletion: bool, - custom_prompt_dict: Dict[str, Any], + custom_prompt_dict: dict, model_response: ModelResponse, print_verbose: Callable, - encoding: Any, + encoding, JWT: str, - logging_obj: Any, - optional_params: Optional[Dict[str, Any]] = None, - litellm_params: Optional[Dict[str, Any]] = None, - logger_fn: Optional[Callable] = None, - headers: Optional[Dict[str, str]] = None, - client: Optional[Any] = None, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers: Optional[dict] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> None: messages = litellm.SnowflakeConfig()._transform_messages( From 162ea295e94d594678ef044c1883ab7be3d47142 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Mon, 3 Mar 2025 01:45:13 -0500 Subject: [PATCH 04/21] [CHORE] Removed old code --- .../snowflake/completion/transformation.py | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/litellm/llms/snowflake/completion/transformation.py b/litellm/llms/snowflake/completion/transformation.py index 1208ce3586..79b25925e7 100644 --- a/litellm/llms/snowflake/completion/transformation.py +++ b/litellm/llms/snowflake/completion/transformation.py @@ -83,28 +83,4 @@ class SnowflakeConfig(OpenAILikeChatConfig): for param, value in non_default_params.items(): if param in supported_openai_params: optional_params[param] = value - return optional_params - - # def _transform_messages( - # self, - # model: str, - # messages: List[AllMessageValues], - # optional_params: dict, - # litellm_params: dict, - # headers: dict, - # ) -> dict: - # config = litellm.SnowflakeConfig.get_config() - # for k, v in config.items(): - # if ( - # k not in optional_params - # ): - # optional_params[k] = v - - # text = " ".join(convert_content_list_to_str(message) for message in messages) - - # data = { - # "text": text, - # **optional_params, - # } - - # return data \ No newline at end of file + return optional_params \ No newline at end of file From c413686eadb0f63d2aca4d2cc68f62dfdfbb6a31 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Mon, 3 Mar 2025 17:49:11 -0500 Subject: [PATCH 05/21] wrote tests for snowflake --- snowflake_testing.py | 9 --- tests/llm_translation/test_snowflake.py | 78 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 9 deletions(-) delete mode 100644 snowflake_testing.py create mode 100644 tests/llm_translation/test_snowflake.py diff --git a/snowflake_testing.py b/snowflake_testing.py deleted file mode 100644 index 7e779b5803..0000000000 --- a/snowflake_testing.py +++ /dev/null @@ -1,9 +0,0 @@ -import os -from litellm import completion - -os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT" -os.environ["SNOWFLAKE_JWT"] = "YOUR JWT" - -messages = [{"role": "user", "content": "Write me a poem about the blue sky"}] - 
-completion(model="snowflake/mistral-7b", messages=messages) \ No newline at end of file diff --git a/tests/llm_translation/test_snowflake.py b/tests/llm_translation/test_snowflake.py new file mode 100644 index 0000000000..e64fb8d6d1 --- /dev/null +++ b/tests/llm_translation/test_snowflake.py @@ -0,0 +1,78 @@ +import os +import sys +import traceback +from dotenv import load_dotenv + +load_dotenv() +import pytest + +from litellm import completion, acompletion, set_verbose + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_chat_completion_snowflake(sync_mode): + try: + messages = [ + { + "role": "user", + "content": "Write me a poem about the blue sky", + }, + ] + + if sync_mode: + response = completion( + model="snowflake/mistral-7b", + messages=messages, + ) + print(response) + assert response is not None + else: + response = await acompletion( + model="snowflake/mistral-7b", + messages=messages, + ) + print(response) + assert response is not None + except Exception as e: + pytest.fail(f"Error occurred: {e}") + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync_mode", [True, False]) +async def test_chat_completion_snowflake_stream(sync_mode): + try: + set_verbose = True + messages = [ + { + "role": "user", + "content": "Write me a poem about the blue sky", + }, + ] + + if sync_mode is False: + response = await acompletion( + model="snowflake/mistral-7b", + messages=messages, + max_tokens=100, + stream=True, + ) + + chunk_count = 0 + async for chunk in response: + print(chunk) + chunk_count += 1 + assert chunk_count > 0 + else: + response = completion( + model="snowflake/mistral-7b", + messages=messages, + max_tokens=100, + stream=True, + ) + + chunk_count = 0 + for chunk in response: + print(chunk) + chunk_count += 1 + assert chunk_count > 0 + except Exception as e: + pytest.fail(f"Error occurred: {e}") From 4d61ac5f134685e583a7e35217a7c7238abf5552 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Mon, 3 Mar 2025 18:11:33 -0500 Subject: [PATCH 06/21] Added models to model_prices_and_context --- ...odel_prices_and_context_window_backup.json | 168 ++++++++++++++++++ 1 file changed, 168 insertions(+) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 96076fa3b8..d465506a3a 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -9461,5 +9461,173 @@ "output_cost_per_token": 0.000000018, "litellm_provider": "jina_ai", "mode": "rerank" + }, + "snowflake/deepseek-r1": { + "max_tokens": 32768, + "max_input_tokens": 32768, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/snowflake-arctic": { + "max_tokens": 4096, + "max_input_tokens": 4096, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/claude-3-5-sonnet": { + "max_tokens": 18000, + "max_input_tokens": 18000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mistral-large": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mistral-large2": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/reka-flash": { + "max_tokens": 100000, + "max_input_tokens": 100000, + "max_output_tokens": 
8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/reka-core": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/jamba-instruct": { + "max_tokens": 256000, + "max_input_tokens": 256000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/jamba-1.5-mini": { + "max_tokens": 256000, + "max_input_tokens": 256000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/jamba-1.5-large": { + "max_tokens": 256000, + "max_input_tokens": 256000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mixtral-8x7b": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama2-70b-chat": { + "max_tokens": 4096, + "max_input_tokens": 4096, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3-8b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3-70b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.1-8b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.1-70b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.3-70b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/snowflake-llama-3.3-70b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.1-405b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/snowflake-llama-3.1-405b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.2-1b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.2-3b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mistral-7b": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/gemma-7b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" } } From bdd03405fe47e7deb1e9e447313f38176a2f82e6 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Mon, 3 Mar 2025 18:18:24 -0500 Subject: [PATCH 07/21] Removed unnecessary comments --- litellm/__init__.py | 1 - litellm/litellm_core_utils/get_llm_provider_logic.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git 
a/litellm/__init__.py b/litellm/__init__.py
index 0086e20899..db386f7ebd 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -600,7 +600,6 @@ ollama_models = ["llama2"]
 maritalk_models = ["maritalk"]
-# Probably shouldn't hard code this, change later
 snowflake_models = [
     "snowflake/deepseek-r1",
     "snowflake/claude-3-5-sonnet",
diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index 5e9489d427..0bf74c5dca 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -577,7 +577,7 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
             or get_secret("SNOWFLAKE_API_BASE")
             or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
         )  # type: ignore
-        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")  # Snowflake doesn't use API keys so this will have to change. Support of OAuth and JWT
+        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
 
     if api_base is not None and not isinstance(api_base, str):
         raise Exception("api base needs to be a string. api_base={}".format(api_base))
From 02dd126be9e6d12c945156bfb518737c57d115ac Mon Sep 17 00:00:00 2001
From: Sunny Wan
Date: Tue, 4 Mar 2025 17:13:00 -0500
Subject: [PATCH 08/21] added documentation for snowflake

---
 docs/my-website/docs/providers/snowflake.md | 101 ++++++++++++++++++++
 docs/my-website/sidebars.js | 1 +
 2 files changed, 102 insertions(+)
 create mode 100644 docs/my-website/docs/providers/snowflake.md

diff --git a/docs/my-website/docs/providers/snowflake.md b/docs/my-website/docs/providers/snowflake.md
new file mode 100644
index 0000000000..04b46b09e1
--- /dev/null
+++ b/docs/my-website/docs/providers/snowflake.md
@@ -0,0 +1,101 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+
+# Snowflake
+LiteLLM supports all Snowflake models.
+- `snowflake/deepseek-r1`
+- `snowflake/claude-3-5-sonnet`
+- `snowflake/llama3.2-1b`
+- `snowflake/llama3.2-3b`
+- `snowflake/llama3.1-8b`
+- `snowflake/llama3.1-70b`
+- `snowflake/llama3.3-70b`
+- `snowflake/snowflake-llama-3.3-70b`
+- `snowflake/llama3.1-405b`
+- `snowflake/snowflake-llama-3.1-405b`
+- `snowflake/snowflake-arctic`
+- `snowflake/reka-core`
+- `snowflake/reka-flash`
+- `snowflake/mistral-large2`
+- `snowflake/mixtral-8x7b`
+- `snowflake/mistral-7b`
+- `snowflake/jamba-instruct`
+- `snowflake/jamba-1.5-mini`
+- `snowflake/jamba-1.5-large`
+- `snowflake/gemma-7b`
+
+Currently, Snowflake's REST API does not have an endpoint for `snowflake-arctic-embed` embedding models. If you want to use these embedding models with Litellm, you can call them through our Hugging Face provider.
+
+Find the Arctic Embed models [here](https://huggingface.co/collections/Snowflake/arctic-embed-661fd57d50fab5fc314e4c18) on Hugging Face.
+## Supported OpenAI Parameters
+```
+    "temperature",
+    "max_tokens",
+    "top_p",
+    "response_format"
+```
+
+## API KEYS
+
+Snowflake does not use API keys. Instead, you access the Snowflake API with your JWT token and account identifier.
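+
+If you don't already have a JWT, the snippet below is a minimal sketch of minting a Snowflake key-pair JWT with PyJWT and the `cryptography` package. The account, user, and key-path values are placeholders, and the claim layout follows Snowflake's key-pair authentication docs -- check those docs for the exact identifier formatting your account expects.
+
+```python
+import base64
+import hashlib
+import os
+from datetime import datetime, timedelta, timezone
+
+import jwt  # PyJWT
+from cryptography.hazmat.primitives import serialization
+
+# Load the RSA private key registered with your Snowflake user (path is a placeholder)
+with open("rsa_key.p8", "rb") as f:
+    private_key = serialization.load_pem_private_key(f.read(), password=None)
+
+# Snowflake identifies the key pair by a SHA256 fingerprint of the public key
+public_bytes = private_key.public_key().public_bytes(
+    serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
+)
+fingerprint = "SHA256:" + base64.b64encode(hashlib.sha256(public_bytes).digest()).decode()
+
+account, user = "MYACCOUNT", "MYUSER"  # placeholders
+now = datetime.now(timezone.utc)
+token = jwt.encode(
+    {
+        "iss": f"{account}.{user}.{fingerprint}",
+        "sub": f"{account}.{user}",
+        "iat": now,
+        "exp": now + timedelta(minutes=59),
+    },
+    private_key,
+    algorithm="RS256",
+)
+os.environ["SNOWFLAKE_JWT"] = token
+```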
+
+```python
+import os
+os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
+os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"
+```
+## Usage
+
+```python
+from litellm import completion
+
+## set ENV variables
+os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
+os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"
+
+# Snowflake call
+response = completion(
+    model="snowflake/mistral-7b",
+    messages = [{ "content": "Hello, how are you?","role": "user"}]
+)
+```
+
+## Usage with LiteLLM Proxy
+
+#### 1. Required env variables
+```bash
+export SNOWFLAKE_JWT=""
+export SNOWFLAKE_ACCOUNT_ID=""
+```
+
+#### 2. Start the proxy
+```yaml
+model_list:
+  - model_name: mistral-7b
+    litellm_params:
+      model: snowflake/mistral-7b
+      api_key: YOUR_JWT
+      api_base: https://YOUR-ACCOUNT-ID.snowflakecomputing.com/api/v2/cortex/inference:complete
+
+```
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+#### 3. Test it
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Content-Type: application/json' \
+--data ' {
+    "model": "snowflake/mistral-7b",
+    "messages": [
+      {
+        "role": "user",
+        "content": "Hello, how are you?"
+      }
+    ]
+  }
+'
+```
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 51c4c8c21e..8315537a17 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -230,6 +230,7 @@ const sidebars = {
         "providers/sambanova",
         "providers/custom_llm_server",
         "providers/petals",
+        "providers/snowflake"
       ],
     },
     {
From a2fed4059e151222e9c1673e39a8fee53b60700e Mon Sep 17 00:00:00 2001
From: Sunny Wan
Date: Wed, 5 Mar 2025 20:32:18 -0500
Subject: [PATCH 09/21] added Snowflake config to ProviderConfigManager

---
 litellm/__init__.py | 1 +
 litellm/utils.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index db386f7ebd..da89bba901 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -843,6 +843,7 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase.chat.transformation import PredibaseConfig
 from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
+from .llms.snowflake.completion.transformation import SnowflakeConfig
 from .llms.cohere.rerank.transformation import CohereRerankConfig
 from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
 from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
diff --git a/litellm/utils.py b/litellm/utils.py
index cbd5e2d0d3..5f01f9f1fc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6071,6 +6071,8 @@ class ProviderConfigManager:
             return litellm.CohereChatConfig()
         elif litellm.LlmProviders.COHERE == provider:
             return litellm.CohereConfig()
+        elif litellm.LlmProviders.SNOWFLAKE == provider:
+            return litellm.SnowflakeConfig()
         elif litellm.LlmProviders.CLARIFAI == provider:
             return litellm.ClarifaiConfig()
         elif litellm.LlmProviders.ANTHROPIC == provider:
From 1b611477f2e35b6228bb846378861012f23a4789 Mon Sep 17 00:00:00 2001
From: Sunny Wan
Date: Mon, 10 Mar 2025 18:03:17 -0400
Subject: [PATCH 10/21] removed supported models from docs

---
 docs/my-website/docs/providers/snowflake.md | 21 ---------------------
 1 file changed, 21 deletions(-)

diff --git a/docs/my-website/docs/providers/snowflake.md b/docs/my-website/docs/providers/snowflake.md
index 04b46b09e1..93b6f4f2c5 100644
--- a/docs/my-website/docs/providers/snowflake.md
+++ b/docs/my-website/docs/providers/snowflake.md
@@
-3,27 +3,6 @@ import TabItem from '@theme/TabItem'; # Snowflake -LiteLLM supports all Snowflake models. -- `snowflake/deepseek-r1` -- `snowflake/claude-3-5-sonnet` -- `snowflake/llama3.2-1b` -- `snowflake/llama3.2-3b` -- `snowflake/llama3.1-8b` -- `snowflake/llama3.1-70b` -- `snowflake/llama3.3-70b` -- `snowflake/snowflake-llama-3.3-70b` -- `snowflake/llama3.1-405b` -- `snowflake/snowflake-llama-3.1-405b` -- `snowflake/snowflake-arctic` -- `snowflake/reka-core` -- `snowflake/reka-flash` -- `snowflake/mistral-large2` -- `snowflake/mixtral-8x7b` -- `snowflake/mistral-7b` -- `snowflake/jamba-instruct` -- `snowflake/jamba-1.5-mini` -- `snowflake/jamba-1.5-large` -- `snowflake/gemma-7b` Currently, Snowflake's REST API does not have an endpoint for `snowflake-arctic-embed` embedding models. If you want to use these embedding models with Litellm, you can call them through our Hugging Face provider. From a775c9ca1385345684c2beaf4a576f2b85c422ef Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Tue, 11 Mar 2025 02:00:52 -0400 Subject: [PATCH 11/21] removed handler and refactored to deepseek/chat format --- litellm/llms/snowflake/completion/handler.py | 63 ----- .../snowflake/completion/transformation.py | 215 ++++++++++++++---- litellm/main.py | 54 +++-- 3 files changed, 204 insertions(+), 128 deletions(-) delete mode 100644 litellm/llms/snowflake/completion/handler.py diff --git a/litellm/llms/snowflake/completion/handler.py b/litellm/llms/snowflake/completion/handler.py deleted file mode 100644 index 85ec676606..0000000000 --- a/litellm/llms/snowflake/completion/handler.py +++ /dev/null @@ -1,63 +0,0 @@ -from litellm.llms.base import BaseLLM -from typing import Any, List, Optional -from typing import List, Dict, Callable, Optional, Any, cast, Union - -import litellm -from litellm.utils import ModelResponse -from litellm.types.llms.openai import AllMessageValues -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler -from ..common_utils import SnowflakeBase - -class SnowflakeChatCompletion(OpenAILikeChatHandler,SnowflakeBase): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def completion( - self, - model: str, - messages: List[Dict[str, Any]], - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - JWT: str, - logging_obj, - optional_params: dict, - acompletion=None, - litellm_params=None, - logger_fn=None, - headers: Optional[dict] = None, - client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, - ) -> None: - - messages = litellm.SnowflakeConfig()._transform_messages( - messages=cast(List[AllMessageValues], messages), model=model - ) - - headers = self.validate_environment( - headers, - JWT - ) - - return super().completion( - model=model, - messages=messages, - api_base=api_base, - custom_llm_provider= "snowflake", - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=JWT, - logging_obj=logging_obj, - optional_params=optional_params, - acompletion=acompletion, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - client=client, - custom_endpoint=True, - ) diff --git a/litellm/llms/snowflake/completion/transformation.py b/litellm/llms/snowflake/completion/transformation.py index 79b25925e7..48593cf0db 100644 --- a/litellm/llms/snowflake/completion/transformation.py +++ 
b/litellm/llms/snowflake/completion/transformation.py @@ -2,52 +2,27 @@ Support for Snowflake REST API ''' import httpx -from typing import List, Optional, Union, Any +from typing import List, Optional, Tuple, Any, TYPE_CHECKING -import litellm -from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.types.llms.openai import AllMessageValues -from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse -from litellm.litellm_core_utils.prompt_templates.common_utils import ( - convert_content_list_to_str, -) -from ...openai_like.chat.transformation import OpenAILikeChatConfig +from litellm.utils import get_secret +from litellm.types.utils import ModelResponse +from litellm.types.llms.openai import ChatCompletionAssistantMessage +from litellm.llms.databricks.streaming_utils import ModelResponseIterator +from ...openai_like.chat.transformation import OpenAIGPTConfig -class SnowflakeConfig(OpenAILikeChatConfig): +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + +class SnowflakeConfig(OpenAIGPTConfig): """ source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex - - The class `SnowflakeConfig` provides configuration for Snowflake's REST API interface. Below are the parameters: - - - `temperature` (float, optional): A value between 0 and 1 that controls randomness. Lower temperatures mean lower randomness. Default: 0 - - - `top_p` (float, optional): Limits generation at each step to top `k` most likely tokens. Default: 0 - - - `max_tokens `(int, optional): The maximum number of tokens in the response. Default: 4096. Maximum allowed: 8192. - - - `guardrails` (bool, optional): Whether to enable Cortex Guard to filter potentially unsafe responses. Default: False. - - - `response_format` (str, optional): A JSON schema that the response should follow - """ - temperature: Optional[float] - top_p: Optional[float] - max_tokens: Optional[int] - guardrails: Optional[bool] - response_format: Optional[str] - - def __init__( - self, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - guardrails: Optional[bool] = None, - response_format: Optional[str] = None, - ) -> None: - locals_ = locals().copy() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) + """ @classmethod def get_config(cls): @@ -60,7 +35,7 @@ class SnowflakeConfig(OpenAILikeChatConfig): "top_p", "response_format" ] - + def map_openai_params( self, non_default_params: dict, @@ -83,4 +58,160 @@ class SnowflakeConfig(OpenAILikeChatConfig): for param, value in non_default_params.items(): if param in supported_openai_params: optional_params[param] = value - return optional_params \ No newline at end of file + return optional_params + + def _convert_tool_response_to_message( + message: ChatCompletionAssistantMessage, json_mode: bool + ) -> ChatCompletionAssistantMessage: + """ + if json_mode is true, convert the returned tool call response to a content with json str + + e.g. 
input:
+
+        {"role": "assistant", "tool_calls": [{"id": "call_5ms4", "type": "function", "function": {"name": "json_tool_call", "arguments": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}}]}
+
+        output:
+
+        {"role": "assistant", "content": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}
+        """
+        if not json_mode:
+            return message
+
+        _tool_calls = message.get("tool_calls")
+
+        if _tool_calls is None or len(_tool_calls) != 1:
+            return message
+
+        message["content"] = _tool_calls[0]["function"].get("arguments") or ""
+        message["tool_calls"] = None
+
+        return message
+
+
+    @staticmethod
+    def transform_response(
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        response_json = raw_response.json()
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        if json_mode:
+            for choice in response_json["choices"]:
+                message = SnowflakeConfig._convert_tool_response_to_message(
+                    choice.get("message"), json_mode
+                )
+                choice["message"] = message
+
+        returned_response = ModelResponse(**response_json)
+
+        returned_response.model = (
+            "snowflake/" + (returned_response.model or "")
+        )
+
+        if model is not None:
+            returned_response._hidden_params["model"] = model
+        return returned_response
+
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_base: str = None,
+        api_key: Optional[str] = None,
+        messages: dict = None,
+        optional_params: dict = None,
+    ) -> dict:
+        """
+        Return headers to use for Snowflake completion request
+
+        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
+        Expected headers:
+        {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": "Bearer " + <JWT>,
+            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+        }
+        """
+
+        if api_key is None:
+            raise ValueError(
+                "Missing Snowflake JWT key"
+            )
+
+        headers.update(
+            {
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                "Authorization": "Bearer " + api_key,
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+            }
+        )
+        return headers
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = (
+            api_base
+            or get_secret("SNOWFLAKE_API_BASE")
+            or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        If api_base is not provided, use the default Snowflake Cortex `inference:complete` endpoint.
+ """ + if not api_base: + api_base = f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete""" + + return api_base + + def transform_request( + self, + model: str, + messages: dict , + optional_params: dict, + litellm_params: dict, + headers: dict + ) -> dict: + stream: bool = optional_params.pop("stream", None) or False + extra_body = optional_params.pop("extra_body", {}) + return { + "model": model, + "messages": messages, + "stream": stream, + **optional_params, + **extra_body, + } + + def get_model_response_iterator( + self, + streaming_response: ModelResponse, + sync_stream: bool, + ): + return ModelResponseIterator(streaming_response=streaming_response, sync_stream=sync_stream) \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index e4abbd8458..9244e47d49 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -146,7 +146,6 @@ from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler from .llms.petals.completion import handler as petals_handler from .llms.predibase.chat.handler import PredibaseChatCompletion from .llms.replicate.chat.handler import completion as replicate_chat_completion -from .llms.snowflake.completion.handler import SnowflakeChatCompletion from .llms.sagemaker.chat.handler import SagemakerChatHandler from .llms.sagemaker.completion.handler import SagemakerLLM from .llms.vertex_ai import vertex_ai_non_gemini @@ -237,7 +236,6 @@ databricks_embedding = DatabricksEmbeddingHandler() base_llm_http_handler = BaseLLMHTTPHandler() base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler() sagemaker_chat_completion = SagemakerChatHandler() -snow_flake_chat_completion = SnowflakeChatCompletion() ####### COMPLETION ENDPOINTS ################ @@ -2977,27 +2975,37 @@ def completion( # type: ignore # noqa: PLR0915 return response response = model_response elif custom_llm_provider == "snowflake" or model in litellm.snowflake_models: - api_base = ( - api_base - or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete""" - or get_secret("SNOWFLAKE_API_BASE") - ) - response = snow_flake_chat_completion.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - JWT=api_key, - logging_obj=logging, - headers=headers, - ) + try: + client = HTTPHandler(timeout=timeout) if stream is False else None # Keep this here, otherwise, the httpx.client closes and streaming is impossible + response = base_llm_http_handler.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + timeout=timeout, # type: ignore + client= client, + custom_llm_provider=custom_llm_provider, + encoding=encoding, + stream=stream, + ) + + + except Exception as e: + ## LOGGING - log the original exception returned + logging.post_call( + input=messages, + api_key=api_key, + original_response=str(e), + additional_args={"headers": headers}, + ) + raise e + elif custom_llm_provider == "custom": url = litellm.api_base or api_base or "" if url is None or url == "": From 1dabc62d7b88e34a6ecab3d90ebff5fe4c252881 Mon Sep 17 00:00:00 2001 From: 
Sunny Wan Date: Tue, 11 Mar 2025 02:05:02 -0400 Subject: [PATCH 12/21] removed hardcoding and added models to model_prices --- litellm/__init__.py | 22 ---- model_prices_and_context_window.json | 168 +++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 22 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index da89bba901..86c75d8b14 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -600,28 +600,6 @@ ollama_models = ["llama2"] maritalk_models = ["maritalk"] -snowflake_models = [ - "snowflake/deepseek-r1", - "snowflake/claude-3-5-sonnet", - "snowflake/llama3.2-1b", - "snowflake/llama3.2-3b", - "snowflake/llama3.1-8b", - "snowflake/llama3.1-70b", - "snowflake/llama3.3-70b", - "snowflake/snowflake-llama-3.3-70b", - "snowflake/llama3.1-405b", - "snowflake/snowflake-llama-3.1-405b", - "snowflake/snowflake-arctic", - "snowflake/reka-core", - "snowflake/reka-flash", - "snowflake/mistral-large2", - "snowflake/mixtral-8x7b", - "snowflake/mistral-7b", - "snowflake/jamba-instruct", - "snowflake/jamba-1.5-mini", - "snowflake/jamba-1.5-large", - "snowflake/gemma-7b" -] model_list = ( open_ai_chat_completion_models diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 96076fa3b8..d465506a3a 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -9461,5 +9461,173 @@ "output_cost_per_token": 0.000000018, "litellm_provider": "jina_ai", "mode": "rerank" + }, + "snowflake/deepseek-r1": { + "max_tokens": 32768, + "max_input_tokens": 32768, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/snowflake-arctic": { + "max_tokens": 4096, + "max_input_tokens": 4096, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/claude-3-5-sonnet": { + "max_tokens": 18000, + "max_input_tokens": 18000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mistral-large": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mistral-large2": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/reka-flash": { + "max_tokens": 100000, + "max_input_tokens": 100000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/reka-core": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/jamba-instruct": { + "max_tokens": 256000, + "max_input_tokens": 256000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/jamba-1.5-mini": { + "max_tokens": 256000, + "max_input_tokens": 256000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/jamba-1.5-large": { + "max_tokens": 256000, + "max_input_tokens": 256000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mixtral-8x7b": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama2-70b-chat": { + "max_tokens": 4096, + "max_input_tokens": 4096, + "max_output_tokens": 8192, + "litellm_provider": 
"snowflake", + "mode": "completion" + }, + "snowflake/llama3-8b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3-70b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.1-8b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.1-70b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.3-70b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/snowflake-llama-3.3-70b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.1-405b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/snowflake-llama-3.1-405b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.2-1b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/llama3.2-3b": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/mistral-7b": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" + }, + "snowflake/gemma-7b": { + "max_tokens": 8000, + "max_input_tokens": 8000, + "max_output_tokens": 8192, + "litellm_provider": "snowflake", + "mode": "completion" } } From 0f0f8f2c0e7ce8700bc459a860dd5346b8ef3361 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Tue, 11 Mar 2025 16:14:31 -0400 Subject: [PATCH 13/21] updated snowflake docs --- docs/my-website/docs/providers/snowflake.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/my-website/docs/providers/snowflake.md b/docs/my-website/docs/providers/snowflake.md index 93b6f4f2c5..6a666e2953 100644 --- a/docs/my-website/docs/providers/snowflake.md +++ b/docs/my-website/docs/providers/snowflake.md @@ -3,6 +3,15 @@ import TabItem from '@theme/TabItem'; # Snowflake +| Property | Details | +|-------|-------| +| Description | The Snowflake Cortex LLM REST API lets you access the COMPLETE function via HTTP POST requests| +| Provider Route on LiteLLM | `snowflake/` | +| Link to Provider Doc | [Vertex AI ↗](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api) | +| Base URL | [https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete/](https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete) | +| Supported Operations | `/completions`| + + Currently, Snowflake's REST API does not have an endpoint for `snowflake-arctic-embed` embedding models. If you want to use these embedding models with Litellm, you can call them through our Hugging Face provider. 
From 844c27a9b2dd7a667279a2a8be7bce3e0d0b7ed2 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Tue, 11 Mar 2025 16:32:15 -0400 Subject: [PATCH 14/21] added mock_tests --- tests/llm_translation/test_snowflake.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/llm_translation/test_snowflake.py b/tests/llm_translation/test_snowflake.py index e64fb8d6d1..139fa16d7a 100644 --- a/tests/llm_translation/test_snowflake.py +++ b/tests/llm_translation/test_snowflake.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv load_dotenv() import pytest -from litellm import completion, acompletion, set_verbose +from litellm import completion, acompletion @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio @@ -23,6 +23,7 @@ async def test_chat_completion_snowflake(sync_mode): response = completion( model="snowflake/mistral-7b", messages=messages, + api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions" ) print(response) assert response is not None @@ -30,6 +31,7 @@ async def test_chat_completion_snowflake(sync_mode): response = await acompletion( model="snowflake/mistral-7b", messages=messages, + api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions" ) print(response) assert response is not None @@ -54,25 +56,21 @@ async def test_chat_completion_snowflake_stream(sync_mode): messages=messages, max_tokens=100, stream=True, + api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions" ) - chunk_count = 0 async for chunk in response: print(chunk) - chunk_count += 1 - assert chunk_count > 0 else: response = completion( model="snowflake/mistral-7b", messages=messages, max_tokens=100, stream=True, + api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions" ) - - chunk_count = 0 + for chunk in response: print(chunk) - chunk_count += 1 - assert chunk_count > 0 except Exception as e: pytest.fail(f"Error occurred: {e}") From 91d6dc388fd8da86d9ae3386fcab98aaf32570b4 Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Tue, 11 Mar 2025 22:20:16 -0400 Subject: [PATCH 15/21] changed completion to chat for modes --- model_prices_and_context_window.json | 48 ++++++++++++++-------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index d465506a3a..140fa36478 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -9467,167 +9467,167 @@ "max_input_tokens": 32768, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/snowflake-arctic": { "max_tokens": 4096, "max_input_tokens": 4096, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/claude-3-5-sonnet": { "max_tokens": 18000, "max_input_tokens": 18000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mistral-large": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mistral-large2": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/reka-flash": { "max_tokens": 100000, "max_input_tokens": 100000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - 
"mode": "completion" + "mode": "chat" }, "snowflake/reka-core": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/jamba-instruct": { "max_tokens": 256000, "max_input_tokens": 256000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/jamba-1.5-mini": { "max_tokens": 256000, "max_input_tokens": 256000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/jamba-1.5-large": { "max_tokens": 256000, "max_input_tokens": 256000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mixtral-8x7b": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama2-70b-chat": { "max_tokens": 4096, "max_input_tokens": 4096, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3-8b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3-70b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.1-8b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.1-70b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.3-70b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/snowflake-llama-3.3-70b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.1-405b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/snowflake-llama-3.1-405b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.2-1b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.2-3b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mistral-7b": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/gemma-7b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" } } From f4539fb95a4c7147d08ca7b8217dd287c97d3f7a Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Tue, 11 Mar 2025 22:21:12 -0400 Subject: [PATCH 16/21] changed to chat/transformations --- litellm/llms/snowflake/{completion => chat}/transformation.py | 0 1 file changed, 0 
insertions(+), 0 deletions(-) rename litellm/llms/snowflake/{completion => chat}/transformation.py (100%) diff --git a/litellm/llms/snowflake/completion/transformation.py b/litellm/llms/snowflake/chat/transformation.py similarity index 100% rename from litellm/llms/snowflake/completion/transformation.py rename to litellm/llms/snowflake/chat/transformation.py From 5dfd0adf19e0a515f5f7d36d6683c2178afab8ef Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Tue, 11 Mar 2025 22:22:28 -0400 Subject: [PATCH 17/21] Update model_prices_and_context_window_backup.json --- ...odel_prices_and_context_window_backup.json | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index d465506a3a..140fa36478 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -9467,167 +9467,167 @@ "max_input_tokens": 32768, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/snowflake-arctic": { "max_tokens": 4096, "max_input_tokens": 4096, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/claude-3-5-sonnet": { "max_tokens": 18000, "max_input_tokens": 18000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mistral-large": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mistral-large2": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/reka-flash": { "max_tokens": 100000, "max_input_tokens": 100000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/reka-core": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/jamba-instruct": { "max_tokens": 256000, "max_input_tokens": 256000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/jamba-1.5-mini": { "max_tokens": 256000, "max_input_tokens": 256000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/jamba-1.5-large": { "max_tokens": 256000, "max_input_tokens": 256000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mixtral-8x7b": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama2-70b-chat": { "max_tokens": 4096, "max_input_tokens": 4096, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3-8b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3-70b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.1-8b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 
8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.1-70b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.3-70b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/snowflake-llama-3.3-70b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.1-405b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/snowflake-llama-3.1-405b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.2-1b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/llama3.2-3b": { "max_tokens": 128000, "max_input_tokens": 128000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/mistral-7b": { "max_tokens": 32000, "max_input_tokens": 32000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" }, "snowflake/gemma-7b": { "max_tokens": 8000, "max_input_tokens": 8000, "max_output_tokens": 8192, "litellm_provider": "snowflake", - "mode": "completion" + "mode": "chat" } } From 70770b6aa4643453fd27ef7a8f1c47d9cd2cfffd Mon Sep 17 00:00:00 2001 From: Sunny Wan Date: Thu, 13 Mar 2025 19:42:10 -0400 Subject: [PATCH 18/21] Removed unnecessary code and refactored --- litellm/__init__.py | 4 +- litellm/llms/snowflake/chat/transformation.py | 43 +------------------ 2 files changed, 3 insertions(+), 44 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 86c75d8b14..55a185e571 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -821,7 +821,7 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig from .llms.predibase.chat.transformation import PredibaseConfig from .llms.replicate.chat.transformation import ReplicateConfig from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig -from .llms.snowflake.completion.transformation import SnowflakeConfig +from .llms.snowflake.chat.transformation import SnowflakeConfig from .llms.cohere.rerank.transformation import CohereRerankConfig from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig @@ -936,7 +936,7 @@ from .llms.openai.chat.o_series_transformation import ( OpenAIOSeriesConfig, ) -from .llms.snowflake.completion.transformation import SnowflakeConfig +from .llms.snowflake.chat.transformation import SnowflakeConfig openaiOSeriesConfig = OpenAIOSeriesConfig() from .llms.openai.chat.gpt_transformation import ( diff --git a/litellm/llms/snowflake/chat/transformation.py b/litellm/llms/snowflake/chat/transformation.py index 48593cf0db..7700607958 100644 --- a/litellm/llms/snowflake/chat/transformation.py +++ b/litellm/llms/snowflake/chat/transformation.py @@ -59,34 +59,6 @@ class SnowflakeConfig(OpenAIGPTConfig): if param in supported_openai_params: optional_params[param] = value 
return optional_params - - def _convert_tool_response_to_message( - message: ChatCompletionAssistantMessage, json_mode: bool - ) -> ChatCompletionAssistantMessage: - """ - if json_mode is true, convert the returned tool call response to a content with json str - - e.g. input: - - {"role": "assistant", "tool_calls": [{"id": "call_5ms4", "type": "function", "function": {"name": "json_tool_call", "arguments": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}}]} - - output: - - {"role": "assistant", "content": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"} - """ - if not json_mode: - return message - - _tool_calls = message.get("tool_calls") - - if _tool_calls is None or len(_tool_calls) != 1: - return message - - message["content"] = _tool_calls[0]["function"].get("arguments") or "" - message["tool_calls"] = None - - return message - @staticmethod def transform_response( @@ -110,12 +82,6 @@ class SnowflakeConfig(OpenAIGPTConfig): additional_args={"complete_input_dict": request_data}, ) - if json_mode: - for choice in response_json["choices"]: - message = SnowflakeConfig._convert_tool_response_to_message( - choice.get("message"), json_mode - ) - choice["message"] = message returned_response = ModelResponse(**response_json) @@ -207,11 +173,4 @@ class SnowflakeConfig(OpenAIGPTConfig): "stream": stream, **optional_params, **extra_body, - } - - def get_model_response_iterator( - self, - streaming_response: ModelResponse, - sync_stream: bool, - ): - return ModelResponseIterator(streaming_response=streaming_response, sync_stream=sync_stream) \ No newline at end of file + } \ No newline at end of file From d3781dfe36898a44c78b0e100157b48ab0f678ce Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 13 Mar 2025 16:58:34 -0700 Subject: [PATCH 19/21] fix linting errors --- litellm/llms/snowflake/chat/transformation.py | 79 ++++++++----------- 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/litellm/llms/snowflake/chat/transformation.py b/litellm/llms/snowflake/chat/transformation.py index 7700607958..d3634e7950 100644 --- a/litellm/llms/snowflake/chat/transformation.py +++ b/litellm/llms/snowflake/chat/transformation.py @@ -1,14 +1,14 @@ -''' +""" Support for Snowflake REST API -''' -import httpx -from typing import List, Optional, Tuple, Any, TYPE_CHECKING +""" +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +import httpx + +from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import AllMessageValues -from litellm.utils import get_secret from litellm.types.utils import ModelResponse -from litellm.types.llms.openai import ChatCompletionAssistantMessage -from litellm.llms.databricks.streaming_utils import ModelResponseIterator from ...openai_like.chat.transformation import OpenAIGPTConfig @@ -19,6 +19,7 @@ if TYPE_CHECKING: else: LiteLLMLoggingObj = Any + class SnowflakeConfig(OpenAIGPTConfig): """ source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex @@ -27,15 +28,10 @@ class SnowflakeConfig(OpenAIGPTConfig): @classmethod def get_config(cls): return super().get_config() - + def get_supported_openai_params(self, model: str) -> List: - return [ - "temperature", - "max_tokens", - "top_p", - "response_format" - ] - + return ["temperature", "max_tokens", "top_p", "response_format"] + def map_openai_params( self, non_default_params: dict, @@ -60,8 +56,8 @@ class SnowflakeConfig(OpenAIGPTConfig): optional_params[param] = value return optional_params - 
@staticmethod
     def transform_response(
+        self,
         model: str,
         raw_response: httpx.Response,
         model_response: ModelResponse,
@@ -82,26 +78,22 @@ class SnowflakeConfig(OpenAIGPTConfig):
             additional_args={"complete_input_dict": request_data},
         )
 
-
         returned_response = ModelResponse(**response_json)
 
-        returned_response.model = (
-            "snowflake/" + (returned_response.model or "")
-        )
+        returned_response.model = "snowflake/" + (returned_response.model or "")
 
         if model is not None:
             returned_response._hidden_params["model"] = model
         return returned_response
 
-
     def validate_environment(
         self,
         headers: dict,
         model: str,
-        api_base: str = None,
+        messages: List[AllMessageValues],
+        optional_params: dict,
         api_key: Optional[str] = None,
-        messages: dict = None,
-        optional_params: dict = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         """
         Return headers to use for Snowflake completion request
@@ -113,57 +105,56 @@ class SnowflakeConfig(OpenAIGPTConfig):
             "Accept": "application/json",
             "Authorization": "Bearer " + <JWT>,
             "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
-            }
+        }
         """
 
         if api_key is None:
-            raise ValueError(
-                "Missing Snowflake JWT key"
-            )
+            raise ValueError("Missing Snowflake JWT key")
 
         headers.update(
             {
                 "Content-Type": "application/json",
                 "Accept": "application/json",
                 "Authorization": "Bearer " + api_key,
-                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
             }
         )
         return headers
-
+
     def _get_openai_compatible_provider_info(
         self, api_base: Optional[str], api_key: Optional[str]
     ) -> Tuple[Optional[str], Optional[str]]:
         api_base = (
             api_base
-            or f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
-            or get_secret("SNOWFLAKE_API_BASE")
-        ) # type: ignore
-        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+            or get_secret_str("SNOWFLAKE_API_BASE")
+            or f"""https://{get_secret_str("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
+        )
+        dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
         return api_base, dynamic_api_key
-
+
     def get_complete_url(
         self,
         api_base: Optional[str],
         model: str,
         optional_params: dict,
+        litellm_params: dict,
         stream: Optional[bool] = None,
     ) -> str:
         """
         If api_base is not provided, use the default Snowflake Cortex inference endpoint.
""" if not api_base: - api_base = f"""https://{get_secret("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete""" + api_base = f"""https://{get_secret_str("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete""" return api_base - + def transform_request( - self, - model: str, - messages: dict , - optional_params: dict, - litellm_params: dict, - headers: dict + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, ) -> dict: stream: bool = optional_params.pop("stream", None) or False extra_body = optional_params.pop("extra_body", {}) @@ -173,4 +164,4 @@ class SnowflakeConfig(OpenAIGPTConfig): "stream": stream, **optional_params, **extra_body, - } \ No newline at end of file + } From 7dd55ce70cbd77592f638b775abec25089b7c03c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 13 Mar 2025 17:49:37 -0700 Subject: [PATCH 20/21] fix @pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") --- tests/local_testing/test_lakera_ai_prompt_injection.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/local_testing/test_lakera_ai_prompt_injection.py b/tests/local_testing/test_lakera_ai_prompt_injection.py index f9035a74f4..e9704143be 100644 --- a/tests/local_testing/test_lakera_ai_prompt_injection.py +++ b/tests/local_testing/test_lakera_ai_prompt_injection.py @@ -55,6 +55,7 @@ def make_config_map(config: dict): ), ) @pytest.mark.asyncio +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_lakera_prompt_injection_detection(): """ Tests to see OpenAI Moderation raises an error for a flagged response @@ -101,6 +102,7 @@ async def test_lakera_prompt_injection_detection(): ), ) @pytest.mark.asyncio +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_lakera_safe_prompt(): """ Nothing should get raised here @@ -126,6 +128,7 @@ async def test_lakera_safe_prompt(): @pytest.mark.asyncio +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_moderations_on_embeddings(): try: temp_router = litellm.Router( @@ -188,6 +191,7 @@ async def test_moderations_on_embeddings(): } ), ) +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_messages_for_disabled_role(spy_post): moderation = lakeraAI_Moderation() data = { @@ -226,6 +230,7 @@ async def test_messages_for_disabled_role(spy_post): ), ) @patch("litellm.add_function_to_prompt", False) +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_system_message_with_function_input(spy_post): moderation = lakeraAI_Moderation() data = { @@ -270,6 +275,7 @@ async def test_system_message_with_function_input(spy_post): ), ) @patch("litellm.add_function_to_prompt", False) +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_multi_message_with_function_input(spy_post): moderation = lakeraAI_Moderation() data = { @@ -317,6 +323,7 @@ async def test_multi_message_with_function_input(spy_post): } ), ) +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_message_ordering(spy_post): moderation = lakeraAI_Moderation() data = { @@ -343,6 +350,7 @@ async def test_message_ordering(spy_post): @pytest.mark.asyncio +@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.") async def test_callback_specific_param_run_pre_call_check_lakera(): from typing import Dict, List, Optional, Union @@ -389,6 +397,7 @@ async def 
test_callback_specific_param_run_pre_call_check_lakera():
 
 
 @pytest.mark.asyncio
+@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
 async def test_callback_specific_thresholds():
     from typing import Dict, List, Optional, Union

From 69b47cf738c47b74027f4da31c771233107c19ff Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 13 Mar 2025 20:10:41 -0700
Subject: [PATCH 21/21] fix code quality check

---
 litellm/llms/snowflake/common_utils.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/litellm/llms/snowflake/common_utils.py b/litellm/llms/snowflake/common_utils.py
index ab8d37bcf1..40c8270f95 100644
--- a/litellm/llms/snowflake/common_utils.py
+++ b/litellm/llms/snowflake/common_utils.py
@@ -1,5 +1,5 @@
-import httpx
-from typing import List, Optional
+from typing import Optional
+
 
 class SnowflakeBase:
     def validate_environment(
@@ -17,22 +17,18 @@ class SnowflakeBase:
             "Accept": "application/json",
             "Authorization": "Bearer " + <JWT>,
             "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
-            }
+        }
         """
 
         if JWT is None:
-            raise ValueError(
-                "Missing Snowflake JWT key"
-            )
+            raise ValueError("Missing Snowflake JWT key")
 
         headers.update(
             {
                 "Content-Type": "application/json",
                 "Accept": "application/json",
                 "Authorization": "Bearer " + JWT,
-                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
+                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
             }
         )
         return headers
-
-
\ No newline at end of file
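The `validate_environment` helper above expects the caller to supply a key-pair JWT but leaves minting one out of scope. A minimal sketch of one way to produce such a token with `cryptography` and `PyJWT`, following Snowflake's documented key-pair claim layout; the account, user, and key path below are placeholders, not values from this series.

```python
import base64
import hashlib
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT
from cryptography.hazmat.primitives import serialization

ACCOUNT = "MYORG-MYACCOUNT"  # placeholder: uppercase account identifier
USER = "MYUSER"              # placeholder: uppercase Snowflake user name

# Load the unencrypted RSA private key registered for the user (path is a placeholder).
with open("rsa_key.p8", "rb") as f:
    private_key = serialization.load_pem_private_key(f.read(), password=None)

# Snowflake identifies the key by a SHA256 fingerprint of the DER-encoded public key.
public_der = private_key.public_key().public_bytes(
    serialization.Encoding.DER,
    serialization.PublicFormat.SubjectPublicKeyInfo,
)
fingerprint = "SHA256:" + base64.b64encode(hashlib.sha256(public_der).digest()).decode()

now = datetime.now(timezone.utc)
token = jwt.encode(
    {
        "iss": f"{ACCOUNT}.{USER}.{fingerprint}",
        "sub": f"{ACCOUNT}.{USER}",
        "iat": now,
        "exp": now + timedelta(minutes=59),  # Snowflake caps JWT lifetime at one hour
    },
    private_key,
    algorithm="RS256",
)

# Export for LiteLLM, e.g. os.environ["SNOWFLAKE_JWT"] = token
print(token)
```

A token minted this way satisfies the `KEYPAIR_JWT` header the helper sets; refreshing it before the one-hour expiry is left to the caller.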