from typing import Any, Callable, Dict, List, Optional, cast

import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse

from ..common_utils import SnowflakeBase


class SnowflakeChatCompletion(OpenAILikeChatHandler, SnowflakeBase):
    """Snowflake chat completions, delegated to the OpenAI-compatible chat handler."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def completion(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_base: str,
        acompletion: bool,
        custom_prompt_dict: Dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding: Any,
        JWT: str,
        logging_obj: Any,
        optional_params: Optional[Dict[str, Any]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        logger_fn: Optional[Callable] = None,
        headers: Optional[Dict[str, str]] = None,
        client: Optional[Any] = None,
    ):
        # Transform the incoming messages into the format Snowflake expects.
        messages = litellm.SnowflakeConfig()._transform_messages(
            messages=cast(List[AllMessageValues], messages), model=model
        )

        # Build the request headers from the provided JWT.
        headers = self.validate_environment(headers, JWT)

        # Delegate the actual request to the OpenAI-compatible handler,
        # routing it to the Snowflake provider.
        return super().completion(
            model=model,
            messages=messages,
            api_base=api_base,
            custom_llm_provider="snowflake",
            custom_prompt_dict=custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            encoding=encoding,
            api_key=JWT,
            logging_obj=logging_obj,
            optional_params=optional_params,
            acompletion=acompletion,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
            client=client,
            custom_endpoint=True,  # api_base is treated as the full endpoint URL
        )
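

# Illustrative usage (a minimal sketch, not part of the handler itself).
# Assumes the "snowflake" provider is registered with litellm; the model name,
# account URL, and JWT below are placeholders you would replace with your own.
#
#   import litellm
#
#   response = litellm.completion(
#       model="snowflake/mistral-large2",
#       messages=[{"role": "user", "content": "Hello from Snowflake"}],
#       api_base="https://<account>.snowflakecomputing.com/<cortex-inference-path>",
#       api_key="<JWT>",
#   )
#   print(response.choices[0].message.content)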