Merge pull request #9222 from BerriAI/litellm_snowflake_pr_mar_13

[Feat] Add Snowflake Cortex to LiteLLM
This commit is contained in:
Ishaan Jaff 2025-03-13 21:35:39 -07:00 committed by GitHub
commit 241a36a74f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 766 additions and 0 deletions


@@ -0,0 +1,89 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Snowflake
| Property | Details |
|-------|-------|
| Description | The Snowflake Cortex LLM REST API lets you access the COMPLETE function via HTTP POST requests. |
| Provider Route on LiteLLM | `snowflake/` |
| Link to Provider Doc | [Snowflake Cortex LLM REST API ↗](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api) |
| Base URL | [https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete](https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete) |
| Supported Operations | `/completions` |
Currently, Snowflake's REST API does not have an endpoint for `snowflake-arctic-embed` embedding models. If you want to use these embedding models with LiteLLM, you can call them through our Hugging Face provider.
Find the Arctic Embed models [here](https://huggingface.co/collections/Snowflake/arctic-embed-661fd57d50fab5fc314e4c18) on Hugging Face.
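For reference, a minimal sketch of calling an Arctic Embed model through the Hugging Face provider — the model ID and the `HUGGINGFACE_API_KEY` setup below are assumptions, so check the Hugging Face provider docs for the exact route:

```python
import os
from litellm import embedding

# Assumption: Hugging Face credentials are required for this route
os.environ["HUGGINGFACE_API_KEY"] = "YOUR HUGGING FACE TOKEN"

response = embedding(
    model="huggingface/Snowflake/snowflake-arctic-embed-m",  # hypothetical model ID from the Arctic Embed collection
    input=["Hello, how are you?"],
)
print(response)
```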
## Supported OpenAI Parameters
```
"temperature",
"max_tokens",
"top_p",
"response_format"
```
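These map directly onto `litellm.completion` keyword arguments; a brief sketch (values are illustrative, and `response_format` support depends on the underlying Cortex model):

```python
from litellm import completion

# Assumes SNOWFLAKE_JWT and SNOWFLAKE_ACCOUNT_ID are set as described below
response = completion(
    model="snowflake/mistral-7b",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    temperature=0.2,
    max_tokens=256,
    top_p=0.9,
)
```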
## API Keys
Snowflake does not use API keys. Instead, you access the Snowflake API with your JWT token and account identifier.
```python
import os
os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"
```
## Usage
```python
import os
from litellm import completion

## set ENV variables
os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"

# Snowflake call
response = completion(
    model="snowflake/mistral-7b",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```
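Streaming and async calls follow the standard LiteLLM interface; a minimal sketch, assuming the same environment variables as above:

```python
import asyncio
from litellm import acompletion, completion

# Streaming: iterate over chunks as they arrive
stream = completion(
    model="snowflake/mistral-7b",
    messages=[{"role": "user", "content": "Write me a poem about the blue sky"}],
    stream=True,
)
for chunk in stream:
    print(chunk)

# Async call
async def main():
    response = await acompletion(
        model="snowflake/mistral-7b",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response)

asyncio.run(main())
```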
## Usage with LiteLLM Proxy
#### 1. Required env variables
```bash
export SNOWFLAKE_JWT=""
export SNOWFLAKE_ACCOUNT_ID=""
```
#### 2. Start the proxy
```yaml
model_list:
  - model_name: mistral-7b
    litellm_params:
      model: snowflake/mistral-7b
      api_key: os.environ/SNOWFLAKE_JWT
      api_base: https://YOUR-ACCOUNT-ID.snowflakecomputing.com/api/v2/cortex/inference:complete
```
```bash
litellm --config /path/to/config.yaml
```
#### 3. Test it
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "snowflake/mistral-7b",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}
'
```
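Any OpenAI-compatible client can also be pointed at the proxy; a hedged sketch using the `openai` Python SDK (the `sk-1234` key is a placeholder for whatever master or virtual key your proxy accepts):

```python
import openai

# Assumes the proxy from step 2 is running on http://0.0.0.0:4000
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="mistral-7b",  # the model_name from the proxy config above
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response)
```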


@@ -231,6 +231,7 @@ const sidebars = {
"providers/sambanova",
"providers/custom_llm_server",
"providers/petals",
"providers/snowflake"
],
},
{


@@ -182,6 +182,7 @@ cloudflare_api_key: Optional[str] = None
baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None
snowflake_key: Optional[str] = None
common_cloud_provider_auth_params: dict = {
"params": ["project", "region_name", "token"],
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
@@ -416,6 +417,7 @@ cerebras_models: List = []
galadriel_models: List = []
sambanova_models: List = []
assemblyai_models: List = []
snowflake_models: List = []
def is_bedrock_pricing_only_model(key: str) -> bool:
@@ -569,6 +571,8 @@ def add_known_models():
assemblyai_models.append(key)
elif value.get("litellm_provider") == "jina_ai":
jina_ai_models.append(key)
elif value.get("litellm_provider") == "snowflake":
snowflake_models.append(key)
add_known_models()
@@ -598,6 +602,7 @@ ollama_models = ["llama2"]
maritalk_models = ["maritalk"]
model_list = (
open_ai_chat_completion_models
+ open_ai_text_completion_models
@@ -642,6 +647,7 @@ model_list = (
+ azure_text_models
+ assemblyai_models
+ jina_ai_models
+ snowflake_models
)
model_list_set = set(model_list)
@@ -697,6 +703,7 @@ models_by_provider: dict = {
"sambanova": sambanova_models,
"assemblyai": assemblyai_models,
"jina_ai": jina_ai_models,
"snowflake": snowflake_models,
}
# mapping for those models which have larger equivalents
@@ -813,6 +820,7 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig
from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.snowflake.chat.transformation import SnowflakeConfig
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
@@ -932,6 +940,8 @@ from .llms.openai.chat.o_series_transformation import (
OpenAIOSeriesConfig,
)
from .llms.snowflake.chat.transformation import SnowflakeConfig
openaiOSeriesConfig = OpenAIOSeriesConfig()
from .llms.openai.chat.gpt_transformation import (
OpenAIGPTConfig,


@@ -571,6 +571,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
or "https://api.galadriel.com/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
elif custom_llm_provider == "snowflake":
api_base = (
api_base
or get_secret("SNOWFLAKE_API_BASE")
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
) # type: ignore
dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
if api_base is not None and not isinstance(api_base, str):
raise Exception("api base needs to be a string. api_base={}".format(api_base))
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
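A rough sketch of how this branch resolves credentials when nothing is passed explicitly — per the added code, `api_base` falls back to `SNOWFLAKE_API_BASE`, then the account-scoped Cortex URL, and the JWT becomes the dynamic API key. The 4-tuple unpacking below reflects how `get_llm_provider` is commonly used and is an assumption, not part of this diff:

```python
import os
from litellm import get_llm_provider

os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"

# Returns (model, custom_llm_provider, dynamic_api_key, api_base)
model, provider, api_key, api_base = get_llm_provider("snowflake/mistral-7b")
print(provider)   # "snowflake"
print(api_base)   # https://<account-id>.snowflakecomputing.com/api/v2/cortex/inference:complete
```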


@@ -0,0 +1,167 @@
"""
Support for Snowflake REST API
"""
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import httpx
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ...openai_like.chat.transformation import OpenAIGPTConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class SnowflakeConfig(OpenAIGPTConfig):
"""
source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
"""
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> List:
return ["temperature", "max_tokens", "top_p", "response_format"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in API call
Args:
non_default_params (dict): Non-default parameters to filter.
optional_params (dict): Optional parameters to update.
model (str): Model name for parameter support check.
Returns:
dict: Updated optional_params with supported non-default parameters.
"""
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
response_json = raw_response.json()
logging_obj.post_call(
input=messages,
api_key="",
original_response=response_json,
additional_args={"complete_input_dict": request_data},
)
returned_response = ModelResponse(**response_json)
returned_response.model = "snowflake/" + (returned_response.model or "")
if model is not None:
returned_response._hidden_params["model"] = model
return returned_response
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
"""
Return headers to use for Snowflake completion request
Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
Expected headers:
{
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": "Bearer " + <JWT>,
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
}
"""
if api_key is None:
raise ValueError("Missing Snowflake JWT key")
headers.update(
{
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": "Bearer " + api_key,
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
}
)
return headers
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
api_base = (
api_base
or get_secret_str("SNOWFLAKE_API_BASE")
or f"""https://{get_secret_str("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
)
dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
return api_base, dynamic_api_key
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
If api_base is not provided, use the default Snowflake Cortex `inference:complete` endpoint.
"""
if not api_base:
api_base = f"""https://{get_secret_str("SNOWFLAKE_ACCOUNT_ID")}.snowflakecomputing.com/api/v2/cortex/inference:complete"""
return api_base
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
stream: bool = optional_params.pop("stream", None) or False
extra_body = optional_params.pop("extra_body", {})
return {
"model": model,
"messages": messages,
"stream": stream,
**optional_params,
**extra_body,
}
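To illustrate what the config produces, a small sketch exercising `map_openai_params` and `transform_request` directly (LiteLLM calls these hooks internally; standalone usage like this is only for illustration):

```python
from litellm.llms.snowflake.chat.transformation import SnowflakeConfig

config = SnowflakeConfig()

# Unsupported params (e.g. "seed") are ignored; supported ones pass through
optional_params = config.map_openai_params(
    non_default_params={"temperature": 0.2, "seed": 42},
    optional_params={},
    model="mistral-7b",
    drop_params=False,
)
print(optional_params)  # {'temperature': 0.2}

# Request body sent to the Cortex inference:complete endpoint
payload = config.transform_request(
    model="mistral-7b",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    optional_params=optional_params,
    litellm_params={},
    headers={},
)
print(payload)  # {'model': 'mistral-7b', 'messages': [...], 'stream': False, 'temperature': 0.2}
```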


@@ -0,0 +1,34 @@
from typing import Optional
class SnowflakeBase:
def validate_environment(
self,
headers: dict,
JWT: Optional[str] = None,
) -> dict:
"""
Return headers to use for Snowflake completion request
Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
Expected headers:
{
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": "Bearer " + <JWT>,
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
}
"""
if JWT is None:
raise ValueError("Missing Snowflake JWT key")
headers.update(
{
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": "Bearer " + JWT,
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
}
)
return headers


@@ -2986,6 +2986,38 @@ def completion( # type: ignore # noqa: PLR0915
)
return response
response = model_response
elif custom_llm_provider == "snowflake" or model in litellm.snowflake_models:
try:
# Keep the client here; otherwise the httpx client closes and streaming becomes impossible
client = HTTPHandler(timeout=timeout) if stream is False else None
response = base_llm_http_handler.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
timeout=timeout, # type: ignore
client=client,
custom_llm_provider=custom_llm_provider,
encoding=encoding,
stream=stream,
)
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"headers": headers},
)
raise e
elif custom_llm_provider == "custom": elif custom_llm_provider == "custom":
url = litellm.api_base or api_base or "" url = litellm.api_base or api_base or ""
if url is None or url == "": if url is None or url == "":
@ -3044,6 +3076,7 @@ def completion( # type: ignore # noqa: PLR0915
model_response.created = int(time.time()) model_response.created = int(time.time())
model_response.model = model model_response.model = model
response = model_response response = model_response
elif ( elif (
custom_llm_provider in litellm._custom_providers custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider ): # Assume custom LLM provider


@@ -10067,5 +10067,173 @@
"output_cost_per_token": 0.000000018,
"litellm_provider": "jina_ai",
"mode": "rerank"
},
"snowflake/deepseek-r1": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/snowflake-arctic": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/claude-3-5-sonnet": {
"max_tokens": 18000,
"max_input_tokens": 18000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mistral-large": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mistral-large2": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/reka-flash": {
"max_tokens": 100000,
"max_input_tokens": 100000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/reka-core": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/jamba-instruct": {
"max_tokens": 256000,
"max_input_tokens": 256000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/jamba-1.5-mini": {
"max_tokens": 256000,
"max_input_tokens": 256000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/jamba-1.5-large": {
"max_tokens": 256000,
"max_input_tokens": 256000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mixtral-8x7b": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama2-70b-chat": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3-8b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3-70b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.1-8b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.1-70b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.3-70b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/snowflake-llama-3.3-70b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.1-405b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/snowflake-llama-3.1-405b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.2-1b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.2-3b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mistral-7b": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/gemma-7b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
}
}
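These registry entries are what LiteLLM's cost and context-window tooling reads back; a hedged sketch using `litellm.get_model_info`, assuming the map above is the one loaded at runtime:

```python
import litellm

info = litellm.get_model_info(model="snowflake/mistral-7b")
print(info["litellm_provider"])  # "snowflake"
print(info["max_input_tokens"])  # 32000
print(info["mode"])              # "chat"
```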


@@ -1967,6 +1967,7 @@ class LlmProviders(str, Enum):
HUMANLOOP = "humanloop"
TOPAZ = "topaz"
ASSEMBLYAI = "assemblyai"
SNOWFLAKE = "snowflake"
# Create a set of all provider values for quick lookup


@@ -6107,6 +6107,8 @@ class ProviderConfigManager:
return litellm.CohereChatConfig()
elif litellm.LlmProviders.COHERE == provider:
return litellm.CohereConfig()
elif litellm.LlmProviders.SNOWFLAKE == provider:
return litellm.SnowflakeConfig()
elif litellm.LlmProviders.CLARIFAI == provider:
return litellm.ClarifaiConfig()
elif litellm.LlmProviders.ANTHROPIC == provider:


@@ -10067,5 +10067,173 @@
"output_cost_per_token": 0.000000018,
"litellm_provider": "jina_ai",
"mode": "rerank"
},
"snowflake/deepseek-r1": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/snowflake-arctic": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/claude-3-5-sonnet": {
"max_tokens": 18000,
"max_input_tokens": 18000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mistral-large": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mistral-large2": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/reka-flash": {
"max_tokens": 100000,
"max_input_tokens": 100000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/reka-core": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/jamba-instruct": {
"max_tokens": 256000,
"max_input_tokens": 256000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/jamba-1.5-mini": {
"max_tokens": 256000,
"max_input_tokens": 256000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/jamba-1.5-large": {
"max_tokens": 256000,
"max_input_tokens": 256000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mixtral-8x7b": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama2-70b-chat": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3-8b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3-70b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.1-8b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.1-70b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.3-70b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/snowflake-llama-3.3-70b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.1-405b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/snowflake-llama-3.1-405b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.2-1b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/llama3.2-3b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/mistral-7b": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
},
"snowflake/gemma-7b": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8192,
"litellm_provider": "snowflake",
"mode": "chat"
}
}


@@ -0,0 +1,76 @@
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import pytest
import litellm
from litellm import completion, acompletion
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_chat_completion_snowflake(sync_mode):
try:
messages = [
{
"role": "user",
"content": "Write me a poem about the blue sky",
},
]
if sync_mode:
response = completion(
model="snowflake/mistral-7b",
messages=messages,
api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions"
)
print(response)
assert response is not None
else:
response = await acompletion(
model="snowflake/mistral-7b",
messages=messages,
api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions"
)
print(response)
assert response is not None
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_chat_completion_snowflake_stream(sync_mode):
try:
litellm.set_verbose = True
messages = [
{
"role": "user",
"content": "Write me a poem about the blue sky",
},
]
if sync_mode is False:
response = await acompletion(
model="snowflake/mistral-7b",
messages=messages,
max_tokens=100,
stream=True,
api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions"
)
async for chunk in response:
print(chunk)
else:
response = completion(
model="snowflake/mistral-7b",
messages=messages,
max_tokens=100,
stream=True,
api_base = "https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions"
)
for chunk in response:
print(chunk)
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@@ -55,6 +55,7 @@ def make_config_map(config: dict):
),
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_lakera_prompt_injection_detection():
"""
Tests to see OpenAI Moderation raises an error for a flagged response
@@ -121,6 +122,7 @@ async def test_lakera_prompt_injection_detection():
),
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_lakera_safe_prompt():
"""
Nothing should get raised here
@@ -146,6 +148,7 @@ async def test_lakera_safe_prompt():
@pytest.mark.asyncio
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_moderations_on_embeddings():
try:
temp_router = litellm.Router(
@@ -208,6 +211,7 @@ async def test_moderations_on_embeddings():
}
),
)
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_messages_for_disabled_role(spy_post):
moderation = lakeraAI_Moderation()
data = {
@@ -246,6 +250,7 @@ async def test_messages_for_disabled_role(spy_post):
),
)
@patch("litellm.add_function_to_prompt", False)
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_system_message_with_function_input(spy_post):
moderation = lakeraAI_Moderation()
data = {
@@ -290,6 +295,7 @@ async def test_system_message_with_function_input(spy_post):
),
)
@patch("litellm.add_function_to_prompt", False)
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_multi_message_with_function_input(spy_post):
moderation = lakeraAI_Moderation()
data = {
@@ -337,6 +343,7 @@ async def test_multi_message_with_function_input(spy_post):
}
),
)
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_message_ordering(spy_post):
moderation = lakeraAI_Moderation()
data = {
@@ -363,6 +370,7 @@ async def test_message_ordering(spy_post):
@pytest.mark.asyncio
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_callback_specific_param_run_pre_call_check_lakera():
from typing import Dict, List, Optional, Union
@@ -409,6 +417,7 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
@pytest.mark.asyncio
@pytest.mark.skip(reason="lakera deprecated their v1 endpoint.")
async def test_callback_specific_thresholds():
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union