diff --git a/docs/my-website/docs/providers/palm.md b/docs/my-website/docs/providers/palm.md
deleted file mode 100644
index 8de1947be9..0000000000
--- a/docs/my-website/docs/providers/palm.md
+++ /dev/null
@@ -1,43 +0,0 @@
-# PaLM API - Google
-
-:::warning
-
-Warning: [The PaLM API is decomissioned by Google](https://ai.google.dev/palm_docs/deprecation) The PaLM API is scheduled to be decomissioned in October 2024. Please upgrade to the Gemini API or Vertex AI API
-
-:::
-
-## Pre-requisites
-* `pip install -q google-generativeai`
-
-## Sample Usage
-```python
-from litellm import completion
-import os
-
-os.environ['PALM_API_KEY'] = ""
-response = completion(
- model="palm/chat-bison",
- messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
-)
-```
-
-## Sample Usage - Streaming
-```python
-from litellm import completion
-import os
-
-os.environ['PALM_API_KEY'] = ""
-response = completion(
- model="palm/chat-bison",
- messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
- stream=True
-)
-
-for chunk in response:
- print(chunk)
-```
-
-## Chat Models
-| Model Name | Function Call | Required OS Variables |
-|------------------|--------------------------------------|-------------------------|
-| chat-bison | `completion('palm/chat-bison', messages)` | `os.environ['PALM_API_KEY']` |
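The PaLM provider docs above are removed because the underlying API is decommissioned; the warning points users at Gemini instead. A minimal migration sketch, assuming a valid `GEMINI_API_KEY` and that the `gemini/gemini-1.5-flash` alias exists in this litellm version (the model name is illustrative, not taken from this diff):

```python
# Hypothetical replacement for the deleted PaLM sample: the same litellm.completion
# call, routed to Google AI Studio's Gemini API instead of the retired PaLM API.
import os

from litellm import completion

os.environ["GEMINI_API_KEY"] = ""  # Google AI Studio key

response = completion(
    model="gemini/gemini-1.5-flash",  # illustrative Gemini model alias
    messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
)
print(response.choices[0].message.content)
```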
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 9aaf77787b..d1075f4e26 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -190,11 +190,9 @@ const sidebars = {
"providers/aleph_alpha",
"providers/baseten",
"providers/openrouter",
- "providers/palm",
"providers/sambanova",
"providers/custom_llm_server",
"providers/petals",
-
],
},
{
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 058fe30d75..c8fa8f3c36 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -601,6 +601,7 @@ openai_compatible_providers: List = [
"cerebras",
"sambanova",
"ai21_chat",
+ "ai21",
"volcengine",
"codestral",
"deepseek",
@@ -853,7 +854,6 @@ class LlmProviders(str, Enum):
OPENROUTER = "openrouter"
VERTEX_AI = "vertex_ai"
VERTEX_AI_BETA = "vertex_ai_beta"
- PALM = "palm"
GEMINI = "gemini"
AI21 = "ai21"
BASETEN = "baseten"
@@ -871,7 +871,6 @@ class LlmProviders(str, Enum):
OLLAMA_CHAT = "ollama_chat"
DEEPINFRA = "deepinfra"
PERPLEXITY = "perplexity"
- ANYSCALE = "anyscale"
MISTRAL = "mistral"
GROQ = "groq"
NVIDIA_NIM = "nvidia_nim"
@@ -1057,10 +1056,15 @@ from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM
from .llms.openai_like.chat.handler import OpenAILikeChatConfig
from .llms.galadriel.chat.transformation import GaladrielChatConfig
-from .llms.huggingface_restapi import HuggingfaceConfig
-from .llms.empower.chat.transformation import EmpowerChatConfig
from .llms.github.chat.transformation import GithubChatConfig
-from .llms.anthropic.chat.handler import AnthropicConfig
+from .llms.empower.chat.transformation import EmpowerChatConfig
+from .llms.huggingface.chat.transformation import (
+ HuggingfaceChatConfig as HuggingfaceConfig,
+)
+from .llms.oobabooga.chat.transformation import OobaboogaConfig
+from .llms.maritalk import MaritalkConfig
+from .llms.openrouter.chat.transformation import OpenrouterConfig
+from .llms.anthropic.chat.transformation import AnthropicConfig
from .llms.anthropic.experimental_pass_through.transformation import (
AnthropicExperimentalPassThroughConfig,
)
@@ -1069,24 +1073,26 @@ from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase import PredibaseConfig
-from .llms.replicate import ReplicateConfig
+from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
-from .llms.cloudflare.chat.transformation import CloudflareChatConfig
-from .llms.ai21.completion import AI21Config
-from .llms.ai21.chat import AI21ChatConfig
+from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.together_ai.chat import TogetherAIConfig
-from .llms.palm import PalmConfig
-from .llms.gemini import GeminiConfig
-from .llms.nlp_cloud import NLPCloudConfig
+from .llms.cloudflare.chat.transformation import CloudflareChatConfig
+from .llms.deprecated_providers.palm import (
+ PalmConfig,
+) # here to prevent breaking changes
+from .llms.nlp_cloud.chat.handler import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
GoogleAIStudioGeminiConfig,
VertexAIConfig,
+ GoogleAIStudioGeminiConfig as GeminiConfig,
)
+
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import (
VertexAITextEmbeddingConfig,
)
@@ -1107,7 +1113,6 @@ from .llms.ollama.completion.transformation import OllamaConfig
from .llms.sagemaker.completion.transformation import SagemakerConfig
from .llms.sagemaker.chat.transformation import SagemakerChatConfig
from .llms.ollama_chat import OllamaChatConfig
-from .llms.maritalk import MaritTalkConfig
from .llms.bedrock.chat.invoke_handler import (
AmazonCohereChatConfig,
AmazonConverseConfig,
@@ -1134,11 +1139,8 @@ from .llms.bedrock.embed.amazon_titan_v2_transformation import (
)
from .llms.cohere.chat.transformation import CohereChatConfig
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
-from .llms.openai.openai import (
- OpenAIConfig,
- MistralEmbeddingConfig,
- DeepInfraConfig,
-)
+from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
+from .llms.deepinfra.chat.transformation import DeepInfraConfig
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
@@ -1167,7 +1169,7 @@ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()
from .llms.cerebras.chat import CerebrasConfig
from .llms.sambanova.chat import SambanovaConfig
-from .llms.ai21.chat import AI21ChatConfig
+from .llms.ai21.chat.transformation import AI21ChatConfig
from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
FireworksAIEmbeddingConfig,
@@ -1183,6 +1185,7 @@ from .llms.azure.azure import (
)
from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
+from .llms.azure.completion.transformation import AzureOpenAITextConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.vllm.completion.transformation import VLLMConfig
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
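The import reshuffle above keeps the old top-level names alive as aliases (`AI21Config`, `GeminiConfig`, `HuggingfaceConfig`, `PalmConfig`) while pointing them at the new transformation modules. A small sanity-check sketch, assuming the aliases resolve exactly as the import statements in this hunk suggest:

```python
# Hedged check of the backwards-compatible aliases introduced in __init__.py.
import litellm

print(litellm.AI21Config is litellm.AI21ChatConfig)               # aliased to the same class
print(litellm.GeminiConfig is litellm.GoogleAIStudioGeminiConfig)  # aliased to the same class
print(litellm.HuggingfaceConfig.__name__)                         # "HuggingfaceChatConfig"
print(litellm.PalmConfig is not None)                             # kept only to avoid breaking imports
```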
diff --git a/litellm/constants.py b/litellm/constants.py
index a5a629c9fd..8f96941025 100644
--- a/litellm/constants.py
+++ b/litellm/constants.py
@@ -3,54 +3,51 @@ DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2
LITELLM_CHAT_PROVIDERS = [
- # "openai",
- # "openai_like",
- # "xai",
- # "custom_openai",
- # "text-completion-openai",
- # "cohere",
- # "cohere_chat",
- # "clarifai",
- # "anthropic",
- # "anthropic_text",
- # "replicate",
- # "huggingface",
- # "together_ai",
- # "openrouter",
- # "vertex_ai",
- # "vertex_ai_beta",
- # "palm",
- # "gemini",
- # "ai21",
- # "baseten",
- # "azure",
- # "azure_text",
- # "azure_ai",
- # "sagemaker",
- # "sagemaker_chat",
- # "bedrock",
+ "openai",
+ "openai_like",
+ "xai",
+ "custom_openai",
+ "text-completion-openai",
+ "cohere",
+ "cohere_chat",
+ "clarifai",
+ "anthropic",
+ "anthropic_text",
+ "replicate",
+ "huggingface",
+ "together_ai",
+ "openrouter",
+ "vertex_ai",
+ "vertex_ai_beta",
+ "gemini",
+ "ai21",
+ "baseten",
+ "azure",
+ "azure_text",
+ "azure_ai",
+ "sagemaker",
+ "sagemaker_chat",
+ "bedrock",
"vllm",
- # "nlp_cloud",
- # "petals",
- # "oobabooga",
+ "nlp_cloud",
+ "petals",
+ "oobabooga",
"ollama",
- # "ollama_chat",
- # "deepinfra",
- # "perplexity",
- # "anyscale",
- # "mistral",
- # "groq",
- # "nvidia_nim",
- # "cerebras",
- # "ai21_chat",
- # "volcengine",
- # "codestral",
- # "text-completion-codestral",
- # "deepseek",
- # "sambanova",
- # "maritalk",
- # "voyage",
- # "cloudflare",
+ "ollama_chat",
+ "deepinfra",
+ "perplexity",
+ "mistral",
+ "groq",
+ "nvidia_nim",
+ "cerebras",
+ "ai21_chat",
+ "volcengine",
+ "codestral",
+ "text-completion-codestral",
+ "deepseek",
+ "sambanova",
+ "maritalk",
+ "cloudflare",
"fireworks_ai",
"friendliai",
"watsonx",
diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index 522068d571..57ab1ec7ef 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -285,9 +285,7 @@ def get_llm_provider( # noqa: PLR0915
):
custom_llm_provider = "vertex_ai"
## ai21
- elif model in litellm.ai21_models:
- custom_llm_provider = "ai21"
- elif model in litellm.ai21_chat_models:
+ elif model in litellm.ai21_chat_models or model in litellm.ai21_models:
custom_llm_provider = "ai21_chat"
api_base = (
api_base
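With this change, the legacy `ai21_models` and the Jamba `ai21_chat_models` both resolve to the `ai21_chat` provider. A hedged sketch of the resolution, assuming `get_llm_provider` keeps its usual 4-tuple return value and that `jamba-1.5-large` is present in `ai21_chat_models`:

```python
# Hedged sketch: both AI21 model families should now route through "ai21_chat".
import litellm

model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="jamba-1.5-large"  # illustrative model name
)
print(provider)  # expected: "ai21_chat"
```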
diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py
index 153b77cc63..d33ccfe969 100644
--- a/litellm/litellm_core_utils/get_supported_openai_params.py
+++ b/litellm/litellm_core_utils/get_supported_openai_params.py
@@ -31,7 +31,7 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama_chat":
- return litellm.OllamaChatConfig().get_supported_openai_params()
+ return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "anthropic":
return litellm.AnthropicConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "fireworks_ai":
@@ -50,7 +50,7 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.CerebrasConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "xai":
return litellm.XAIChatConfig().get_supported_openai_params(model=model)
- elif custom_llm_provider == "ai21_chat":
+ elif custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21":
return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "volcengine":
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
@@ -97,79 +97,50 @@ def get_supported_openai_params( # noqa: PLR0915
model=model
)
else:
- return litellm.AzureOpenAIConfig().get_supported_openai_params()
+ return litellm.AzureOpenAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "openrouter":
- return [
- "temperature",
- "top_p",
- "frequency_penalty",
- "presence_penalty",
- "repetition_penalty",
- "seed",
- "max_tokens",
- "logit_bias",
- "logprobs",
- "top_logprobs",
- "response_format",
- "stop",
- "tools",
- "tool_choice",
- ]
+ return litellm.OpenrouterConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
# mistal and codestral api have the exact same params
if request_type == "chat_completion":
- return litellm.MistralConfig().get_supported_openai_params()
+ return litellm.MistralConfig().get_supported_openai_params(model=model)
elif request_type == "embeddings":
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "text-completion-codestral":
- return litellm.MistralTextCompletionConfig().get_supported_openai_params()
+ return litellm.MistralTextCompletionConfig().get_supported_openai_params(
+ model=model
+ )
+ elif custom_llm_provider == "sambanova":
+ return litellm.SambanovaConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "replicate":
- return [
- "stream",
- "temperature",
- "max_tokens",
- "top_p",
- "stop",
- "seed",
- "tools",
- "tool_choice",
- "functions",
- "function_call",
- ]
+ return litellm.ReplicateConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "huggingface":
- return litellm.HuggingfaceConfig().get_supported_openai_params()
+ return litellm.HuggingfaceConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "jina_ai":
if request_type == "embeddings":
return litellm.JinaAIEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai":
return litellm.TogetherAIConfig().get_supported_openai_params(model=model)
- elif custom_llm_provider == "ai21":
- return [
- "stream",
- "n",
- "temperature",
- "max_tokens",
- "top_p",
- "stop",
- "frequency_penalty",
- "presence_penalty",
- ]
elif custom_llm_provider == "databricks":
if request_type == "chat_completion":
return litellm.DatabricksConfig().get_supported_openai_params(model=model)
elif request_type == "embeddings":
return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
- return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
+ return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params(
+ model=model
+ )
elif custom_llm_provider == "vertex_ai":
if request_type == "chat_completion":
if model.startswith("meta/"):
return litellm.VertexAILlama3Config().get_supported_openai_params()
if model.startswith("mistral"):
- return litellm.MistralConfig().get_supported_openai_params()
+ return litellm.MistralConfig().get_supported_openai_params(model=model)
if model.startswith("codestral"):
return (
- litellm.MistralTextCompletionConfig().get_supported_openai_params()
+ litellm.MistralTextCompletionConfig().get_supported_openai_params(
+ model=model
+ )
)
if model.startswith("claude"):
return litellm.VertexAIAnthropicConfig().get_supported_openai_params(
@@ -180,7 +151,7 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "vertex_ai_beta":
if request_type == "chat_completion":
- return litellm.VertexGeminiConfig().get_supported_openai_params()
+ return litellm.VertexGeminiConfig().get_supported_openai_params(model=model)
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker":
@@ -199,20 +170,11 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "cloudflare":
return litellm.CloudflareChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "nlp_cloud":
- return [
- "max_tokens",
- "stream",
- "temperature",
- "top_p",
- "presence_penalty",
- "frequency_penalty",
- "n",
- "stop",
- ]
+ return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "petals":
return ["max_tokens", "temperature", "top_p", "stream"]
elif custom_llm_provider == "deepinfra":
- return litellm.DeepInfraConfig().get_supported_openai_params()
+ return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "perplexity":
return [
"temperature",
diff --git a/litellm/llms/ai21/chat.py b/litellm/llms/ai21/chat/transformation.py
similarity index 63%
rename from litellm/llms/ai21/chat.py
rename to litellm/llms/ai21/chat/transformation.py
index 7a60b1904f..06f87a6fe4 100644
--- a/litellm/llms/ai21/chat.py
+++ b/litellm/llms/ai21/chat/transformation.py
@@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
import types
from typing import Optional, Union
+from ...openai_like.chat.transformation import OpenAILikeChatConfig
-class AI21ChatConfig:
+
+class AI21ChatConfig(OpenAILikeChatConfig):
"""
Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters
@@ -19,8 +21,6 @@ class AI21ChatConfig:
response_format: Optional[dict] = None
documents: Optional[list] = None
max_tokens: Optional[int] = None
- temperature: Optional[float] = None
- top_p: Optional[float] = None
stop: Optional[Union[str, list]] = None
n: Optional[int] = None
stream: Optional[bool] = None
@@ -49,21 +49,7 @@ class AI21ChatConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
"""
@@ -77,22 +63,9 @@ class AI21ChatConfig:
"max_tokens",
"max_completion_tokens",
"temperature",
- "top_p",
"stop",
"n",
"stream",
"seed",
"tool_choice",
- "user",
]
-
- def map_openai_params(
- self, model: str, non_default_params: dict, optional_params: dict
- ) -> dict:
- supported_openai_params = self.get_supported_openai_params(model=model)
- for param, value in non_default_params.items():
- if param == "max_completion_tokens":
- optional_params["max_tokens"] = value
- elif param in supported_openai_params:
- optional_params[param] = value
- return optional_params
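`AI21ChatConfig` now subclasses `OpenAILikeChatConfig`, dropping its bespoke `get_config`/`map_openai_params` and trimming `top_p` and `user` from the advertised parameters. A hedged sketch of the remaining override, assuming the class can still be instantiated with no arguments:

```python
# Hedged sketch: the explicit max_completion_tokens -> max_tokens mapping was
# removed here, presumably in favor of the parent class's map_openai_params.
from litellm.llms.ai21.chat.transformation import AI21ChatConfig

config = AI21ChatConfig()
print(config.get_supported_openai_params(model="jamba-1.5-large"))
# "top_p" and "user" are no longer in this list after the change above.
```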
diff --git a/litellm/llms/ai21/completion.py b/litellm/llms/ai21/completion.py
deleted file mode 100644
index 0edd7e2aaf..0000000000
--- a/litellm/llms/ai21/completion.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import json
-import os
-import time # type: ignore
-import traceback
-import types
-from enum import Enum
-from typing import Callable, Optional
-
-import httpx
-import requests # type: ignore
-
-import litellm
-from litellm.utils import Choices, Message, ModelResponse
-
-
-class AI21Error(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- self.request = httpx.Request(
- method="POST", url="https://api.ai21.com/studio/v1/"
- )
- self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
-class AI21Config:
- """
- Reference: https://docs.ai21.com/reference/j2-complete-ref
-
- The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters:
-
- - `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful.
-
- - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
-
- - `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated.
-
- - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
-
- - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
-
- - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
-
- - `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position.
-
- - `frequencyPenalty` (object): Placeholder for frequency penalty object.
-
- - `presencePenalty` (object): Placeholder for presence penalty object.
-
- - `countPenalty` (object): Placeholder for count penalty object.
- """
-
- numResults: Optional[int] = None
- maxTokens: Optional[int] = None
- minTokens: Optional[int] = None
- temperature: Optional[float] = None
- topP: Optional[float] = None
- stopSequences: Optional[list] = None
- topKReturn: Optional[int] = None
- frequencePenalty: Optional[dict] = None
- presencePenalty: Optional[dict] = None
- countPenalty: Optional[dict] = None
-
- def __init__(
- self,
- numResults: Optional[int] = None,
- maxTokens: Optional[int] = None,
- minTokens: Optional[int] = None,
- temperature: Optional[float] = None,
- topP: Optional[float] = None,
- stopSequences: Optional[list] = None,
- topKReturn: Optional[int] = None,
- frequencePenalty: Optional[dict] = None,
- presencePenalty: Optional[dict] = None,
- countPenalty: Optional[dict] = None,
- ) -> None:
- locals_ = locals()
- for key, value in locals_.items():
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
-
-
-def validate_environment(api_key):
- if api_key is None:
- raise ValueError(
- "Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
- )
- headers = {
- "accept": "application/json",
- "content-type": "application/json",
- "Authorization": "Bearer " + api_key,
- }
- return headers
-
-
-def completion(
- model: str,
- messages: list,
- api_base: str,
- model_response: ModelResponse,
- print_verbose: Callable,
- encoding,
- api_key,
- logging_obj,
- optional_params: dict,
- litellm_params=None,
- logger_fn=None,
-):
- headers = validate_environment(api_key)
- model = model
- prompt = ""
- for message in messages:
- if "role" in message:
- if message["role"] == "user":
- prompt += f"{message['content']}"
- else:
- prompt += f"{message['content']}"
- else:
- prompt += f"{message['content']}"
-
- ## Load Config
- config = litellm.AI21Config.get_config()
- for k, v in config.items():
- if (
- k not in optional_params
- ): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
- optional_params[k] = v
-
- data = {
- "prompt": prompt,
- # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
- **optional_params,
- }
-
- ## LOGGING
- logging_obj.pre_call(
- input=prompt,
- api_key=api_key,
- additional_args={"complete_input_dict": data},
- )
- ## COMPLETION CALL
- response = requests.post(
- api_base + model + "/complete", headers=headers, data=json.dumps(data)
- )
- if response.status_code != 200:
- raise AI21Error(status_code=response.status_code, message=response.text)
- if "stream" in optional_params and optional_params["stream"] is True:
- return response.iter_lines()
- else:
- ## LOGGING
- logging_obj.post_call(
- input=prompt,
- api_key=api_key,
- original_response=response.text,
- additional_args={"complete_input_dict": data},
- )
- ## RESPONSE OBJECT
- completion_response = response.json()
- try:
- choices_list = []
- for idx, item in enumerate(completion_response["completions"]):
- if len(item["data"]["text"]) > 0:
- message_obj = Message(content=item["data"]["text"])
- else:
- message_obj = Message(content=None)
- choice_obj = Choices(
- finish_reason=item["finishReason"]["reason"],
- index=idx + 1,
- message=message_obj,
- )
- choices_list.append(choice_obj)
- model_response.choices = choices_list # type: ignore
- except Exception:
- raise AI21Error(
- message=traceback.format_exc(), status_code=response.status_code
- )
-
- ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
- prompt_tokens = len(encoding.encode(prompt))
- completion_tokens = len(
- encoding.encode(model_response["choices"][0]["message"].get("content"))
- )
-
- model_response.created = int(time.time())
- model_response.model = model
- setattr(
- model_response,
- "usage",
- litellm.Usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- ),
- )
- return model_response
-
-
-def embedding():
- # logic for parsing in - calling - parsing out model embedding calls
- pass
diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py
index 444082fac5..275e3b868d 100644
--- a/litellm/llms/anthropic/chat/handler.py
+++ b/litellm/llms/anthropic/chat/handler.py
@@ -52,20 +52,6 @@ from ..common_utils import AnthropicError, process_anthropic_headers
from .transformation import AnthropicConfig
-# makes headers for API call
-def validate_environment(
- api_key,
- user_headers,
- model,
- messages: List[AllMessageValues],
- is_vertex_request: bool,
- tools: Optional[List[AllAnthropicToolsValues]],
- anthropic_version: Optional[str] = None,
-):
-
- pass
-
-
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
@@ -239,7 +225,7 @@ class AnthropicChatCompletion(BaseLLM):
data: dict,
optional_params: dict,
json_mode: bool,
- litellm_params=None,
+ litellm_params: dict,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
@@ -283,6 +269,7 @@ class AnthropicChatCompletion(BaseLLM):
request_data=data,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
encoding=encoding,
json_mode=json_mode,
)
@@ -460,6 +447,7 @@ class AnthropicChatCompletion(BaseLLM):
request_data=data,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
encoding=encoding,
json_mode=json_mode,
)
diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py
index 13016d1595..8454952886 100644
--- a/litellm/llms/anthropic/chat/transformation.py
+++ b/litellm/llms/anthropic/chat/transformation.py
@@ -567,6 +567,7 @@ class AnthropicConfig(BaseConfig):
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
@@ -715,11 +716,6 @@ class AnthropicConfig(BaseConfig):
return litellm.Message(content=json_mode_content_str)
return None
- def _transform_messages(
- self, messages: List[AllMessageValues]
- ) -> List[AllMessageValues]:
- return messages
-
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py
index 7436327324..e556b833ba 100644
--- a/litellm/llms/anthropic/completion/transformation.py
+++ b/litellm/llms/anthropic/completion/transformation.py
@@ -180,6 +180,7 @@ class AnthropicTextConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py
index fb2dfbc9f1..2735884f70 100644
--- a/litellm/llms/azure/azure.py
+++ b/litellm/llms/azure/azure.py
@@ -35,38 +35,11 @@ from ...types.llms.openai import (
RetrieveBatchRequest,
)
from ..base import BaseLLM
-from .common_utils import process_azure_headers
+from .common_utils import AzureOpenAIError, process_azure_headers
azure_ad_cache = DualCache()
-class AzureOpenAIError(Exception):
- def __init__(
- self,
- status_code,
- message,
- request: Optional[httpx.Request] = None,
- response: Optional[httpx.Response] = None,
- headers: Optional[httpx.Headers] = None,
- ):
- self.status_code = status_code
- self.message = message
- self.headers = headers
- if request:
- self.request = request
- else:
- self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
- if response:
- self.response = response
- else:
- self.response = httpx.Response(
- status_code=status_code, request=self.request
- )
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
class AzureOpenAIAssistantsAPIConfig:
"""
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
@@ -412,8 +385,12 @@ class AzureChatCompletion(BaseLLM):
data = {"model": None, "messages": messages, **optional_params}
else:
- data = litellm.AzureOpenAIConfig.transform_request(
- model=model, messages=messages, optional_params=optional_params
+ data = litellm.AzureOpenAIConfig().transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers or {},
)
if acompletion is True:
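`transform_request` on `AzureOpenAIConfig` is now an instance method that also takes `litellm_params` and `headers`, mirroring the call site above. A hedged sketch with dummy values, assuming `convert_to_azure_openai_messages` passes a plain text message through unchanged:

```python
# Hedged sketch of the new call shape used by the Azure handler.
import litellm

data = litellm.AzureOpenAIConfig().transform_request(
    model="gpt-4o",  # illustrative deployment/model name
    messages=[{"role": "user", "content": "hi"}],
    optional_params={"temperature": 0.2},
    litellm_params={},
    headers={},
)
print(data)  # expected shape: {"model": ..., "messages": [...], "temperature": 0.2}
```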
diff --git a/litellm/llms/azure/chat/gpt_transformation.py b/litellm/llms/azure/chat/gpt_transformation.py
index 8429edadd2..4e308d0ea2 100644
--- a/litellm/llms/azure/chat/gpt_transformation.py
+++ b/litellm/llms/azure/chat/gpt_transformation.py
@@ -1,7 +1,10 @@
import types
-from typing import List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Type, Union
+
+from httpx._models import Headers, Response
import litellm
+from litellm.llms.base_llm.transformation import BaseLLMException
from ....exceptions import UnsupportedParamsError
from ....types.llms.openai import (
@@ -11,10 +14,19 @@ from ....types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
+from ...base_llm.transformation import BaseConfig
from ...prompt_templates.factory import convert_to_azure_openai_messages
+from ..common_utils import AzureOpenAIError
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
-class AzureOpenAIConfig:
+class AzureOpenAIConfig(BaseConfig):
"""
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
@@ -61,23 +73,9 @@ class AzureOpenAIConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
- def get_supported_openai_params(self):
+ def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"n",
@@ -110,10 +108,10 @@ class AzureOpenAIConfig:
non_default_params: dict,
optional_params: dict,
model: str,
- api_version: str, # Y-M-D-{optional}
- drop_params,
+ drop_params: bool,
+ api_version: str = "",
) -> dict:
- supported_openai_params = self.get_supported_openai_params()
+ supported_openai_params = self.get_supported_openai_params(model)
api_version_times = api_version.split("-")
api_version_year = api_version_times[0]
@@ -204,9 +202,13 @@ class AzureOpenAIConfig:
return optional_params
- @classmethod
def transform_request(
- cls, model: str, messages: List[AllMessageValues], optional_params: dict
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
) -> dict:
messages = convert_to_azure_openai_messages(messages)
return {
@@ -215,6 +217,24 @@ class AzureOpenAIConfig:
**optional_params,
}
+ def transform_response(
+ self,
+ model: str,
+ raw_response: Response,
+ model_response: litellm.ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> litellm.ModelResponse:
+ raise NotImplementedError(
+ "Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
+ )
+
def get_mapped_special_auth_params(self) -> dict:
return {"token": "azure_ad_token"}
@@ -246,3 +266,22 @@ class AzureOpenAIConfig:
"westus3",
"westus4",
]
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, Headers]
+ ) -> BaseLLMException:
+ return AzureOpenAIError(
+ message=error_message, status_code=status_code, headers=headers
+ )
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ ) -> dict:
+ raise NotImplementedError(
+ "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
+ )
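`AzureOpenAIConfig` now derives from `BaseConfig`: `get_supported_openai_params` takes the model name, and errors flow through the shared `BaseLLMException` hierarchy via `get_error_class`. A hedged sketch of both, using throwaway values:

```python
# Hedged sketch of the new BaseConfig-style surface on AzureOpenAIConfig.
import litellm

cfg = litellm.AzureOpenAIConfig()

print("temperature" in cfg.get_supported_openai_params(model="gpt-4o"))

err = cfg.get_error_class(error_message="bad request", status_code=400, headers={})
print(type(err).__name__, err.status_code)  # AzureOpenAIError 400
```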
diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py
index 01faa40264..b5033295c4 100644
--- a/litellm/llms/azure/common_utils.py
+++ b/litellm/llms/azure/common_utils.py
@@ -1,7 +1,27 @@
-from typing import Union
+from typing import Optional, Union
import httpx
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class AzureOpenAIError(BaseLLMException):
+ def __init__(
+ self,
+ status_code,
+ message,
+ request: Optional[httpx.Request] = None,
+ response: Optional[httpx.Response] = None,
+ headers: Optional[Union[httpx.Headers, dict]] = None,
+ ):
+ super().__init__(
+ status_code=status_code,
+ message=message,
+ request=request,
+ response=response,
+ headers=headers,
+ )
+
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
diff --git a/litellm/llms/azure_text.py b/litellm/llms/azure/completion/handler.py
similarity index 81%
rename from litellm/llms/azure_text.py
rename to litellm/llms/azure/completion/handler.py
index f72accfb60..193776fd1d 100644
--- a/litellm/llms/azure_text.py
+++ b/litellm/llms/azure/completion/handler.py
@@ -19,104 +19,16 @@ from litellm.utils import (
convert_to_model_response_object,
)
-from .base import BaseLLM
-from .openai.completion.handler import OpenAITextCompletion
-from .openai.completion.transformation import OpenAITextCompletionConfig
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ...base import BaseLLM
+from ...openai.completion.handler import OpenAITextCompletion
+from ...openai.completion.transformation import OpenAITextCompletionConfig
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import AzureOpenAIError
openai_text_completion_config = OpenAITextCompletionConfig()
-class AzureOpenAIError(Exception):
- def __init__(
- self,
- status_code,
- message,
- request: Optional[httpx.Request] = None,
- response: Optional[httpx.Response] = None,
- headers: Optional[httpx.Headers] = None,
- ):
- self.status_code = status_code
- self.message = message
- self.headers = headers
- if request:
- self.request = request
- else:
- self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
- if response:
- self.response = response
- else:
- self.response = httpx.Response(
- status_code=status_code, request=self.request
- )
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
-class AzureOpenAIConfig(OpenAIConfig):
- """
- Reference: https://platform.openai.com/docs/api-reference/chat/create
-
- The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
-
- - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
-
- - `function_call` (string or object): This optional parameter controls how the model calls functions.
-
- - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
-
- - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
-
- - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
-
- - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
-
- - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
-
- - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
-
- - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
-
- - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
- """
-
- def __init__(
- self,
- frequency_penalty: Optional[int] = None,
- function_call: Optional[Union[str, dict]] = None,
- functions: Optional[list] = None,
- logit_bias: Optional[dict] = None,
- max_tokens: Optional[int] = None,
- n: Optional[int] = None,
- presence_penalty: Optional[int] = None,
- stop: Optional[Union[str, list]] = None,
- temperature: Optional[int] = None,
- top_p: Optional[int] = None,
- ) -> None:
- super().__init__(
- frequency_penalty=frequency_penalty,
- function_call=function_call,
- functions=functions,
- logit_bias=logit_bias,
- max_tokens=max_tokens,
- n=n,
- presence_penalty=presence_penalty,
- stop=stop,
- temperature=temperature,
- top_p=top_p,
- )
-
-
def select_azure_base_url_or_endpoint(azure_client_params: dict):
- # azure_client_params = {
- # "api_version": api_version,
- # "azure_endpoint": api_base,
- # "azure_deployment": model,
- # "http_client": litellm.client_session,
- # "max_retries": max_retries,
- # "timeout": timeout,
- # }
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
diff --git a/litellm/llms/azure/completion/transformation.py b/litellm/llms/azure/completion/transformation.py
new file mode 100644
index 0000000000..bc7b97c6ef
--- /dev/null
+++ b/litellm/llms/azure/completion/transformation.py
@@ -0,0 +1,53 @@
+from typing import Optional, Union
+
+from ...openai.completion.transformation import OpenAITextCompletionConfig
+
+
+class AzureOpenAITextConfig(OpenAITextCompletionConfig):
+ """
+ Reference: https://platform.openai.com/docs/api-reference/chat/create
+
+ The class `AzureOpenAITextConfig` provides configuration for OpenAI's text completion API interface, for use with Azure. It inherits from `OpenAITextCompletionConfig`. Below are the parameters:
+
+ - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
+
+ - `function_call` (string or object): This optional parameter controls how the model calls functions.
+
+ - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
+
+ - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
+
+ - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
+
+ - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
+
+ - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
+
+ - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
+
+ - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
+
+ - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
+ """
+
+ def __init__(
+ self,
+ frequency_penalty: Optional[int] = None,
+ logit_bias: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[int] = None,
+ stop: Optional[Union[str, list]] = None,
+ temperature: Optional[int] = None,
+ top_p: Optional[int] = None,
+ ) -> None:
+ super().__init__(
+ frequency_penalty=frequency_penalty,
+ logit_bias=logit_bias,
+ max_tokens=max_tokens,
+ n=n,
+ presence_penalty=presence_penalty,
+ stop=stop,
+ temperature=temperature,
+ top_p=top_p,
+ )
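The new `AzureOpenAITextConfig` is a thin subclass of `OpenAITextCompletionConfig` and is exported at the package root (see the `__init__.py` hunk above). A hedged instantiation sketch:

```python
# Hedged sketch: the Azure text-completion config reuses the OpenAI text config.
import litellm
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig

cfg = litellm.AzureOpenAITextConfig(max_tokens=128, temperature=0)
print(isinstance(cfg, OpenAITextCompletionConfig))  # True
```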
diff --git a/litellm/llms/azure_ai/chat/__init__.py b/litellm/llms/azure_ai/chat/__init__.py
deleted file mode 100644
index 62378de405..0000000000
--- a/litellm/llms/azure_ai/chat/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .handler import AzureAIChatCompletion
diff --git a/litellm/llms/azure_ai/chat/handler.py b/litellm/llms/azure_ai/chat/handler.py
index 711d31b2da..d141498cc4 100644
--- a/litellm/llms/azure_ai/chat/handler.py
+++ b/litellm/llms/azure_ai/chat/handler.py
@@ -1,59 +1,3 @@
-from typing import Any, Callable, List, Optional, Union
-
-from httpx._config import Timeout
-
-from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
-from litellm.llms.openai.openai import OpenAIChatCompletion
-from litellm.types.utils import ModelResponse
-from litellm.utils import CustomStreamWrapper
-
-from .transformation import AzureAIStudioConfig
-
-
-class AzureAIChatCompletion(OpenAIChatCompletion):
- def completion(
- self,
- model_response: ModelResponse,
- timeout: Union[float, Timeout],
- optional_params: dict,
- logging_obj: Any,
- model: Optional[str] = None,
- messages: Optional[list] = None,
- print_verbose: Optional[Callable[..., Any]] = None,
- api_key: Optional[str] = None,
- api_base: Optional[str] = None,
- acompletion: bool = False,
- litellm_params=None,
- logger_fn=None,
- headers: Optional[dict] = None,
- custom_prompt_dict: dict = {},
- client=None,
- organization: Optional[str] = None,
- custom_llm_provider: Optional[str] = None,
- drop_params: Optional[bool] = None,
- ):
-
- transformed_messages = AzureAIStudioConfig()._transform_messages(
- messages=messages # type: ignore
- )
-
- return super().completion(
- model_response,
- timeout,
- optional_params,
- logging_obj,
- model,
- transformed_messages,
- print_verbose,
- api_key,
- api_base,
- acompletion,
- litellm_params,
- logger_fn,
- headers,
- custom_prompt_dict,
- client,
- organization,
- custom_llm_provider,
- drop_params,
- )
+"""
+LLM calling is handled in `openai/openai.py`.
+"""
diff --git a/litellm/llms/base_llm/transformation.py b/litellm/llms/base_llm/transformation.py
index c87c7a81d0..aa0f0838d3 100644
--- a/litellm/llms/base_llm/transformation.py
+++ b/litellm/llms/base_llm/transformation.py
@@ -13,6 +13,7 @@ from typing import (
Iterator,
List,
Optional,
+ TypedDict,
Union,
)
@@ -34,15 +35,25 @@ class BaseLLMException(Exception):
self,
status_code: int,
message: str,
- headers: Optional[Union[httpx.Headers, Dict]] = None,
+ headers: Optional[Union[dict, httpx.Headers]] = None,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message: str = message
self.headers = headers
- self.request = httpx.Request(method="POST", url="https://docs.litellm.ai/docs")
- self.response = httpx.Response(status_code=status_code, request=self.request)
+ if request:
+ self.request = request
+ else:
+ self.request = httpx.Request(
+ method="POST", url="https://docs.litellm.ai/docs"
+ )
+ if response:
+ self.response = response
+ else:
+ self.response = httpx.Response(
+ status_code=status_code, request=self.request
+ )
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
@@ -117,12 +128,6 @@ class BaseConfig(ABC):
) -> dict:
pass
- @abstractmethod
- def _transform_messages(
- self, messages: List[AllMessageValues]
- ) -> List[AllMessageValues]:
- pass
-
@abstractmethod
def transform_response(
self,
@@ -133,7 +138,8 @@ class BaseConfig(ABC):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
- encoding: str,
+ litellm_params: dict,
+ encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
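`BaseLLMException` now keeps any `request`/`response` it is handed instead of always synthesizing placeholders, and `transform_response` gains a `litellm_params` argument (threaded through the provider configs elsewhere in this diff). A hedged sketch of the exception change with made-up values:

```python
# Hedged sketch: the supplied httpx request/response are preserved on the error.
import httpx

from litellm.llms.base_llm.transformation import BaseLLMException

request = httpx.Request(method="POST", url="https://example.com/v1/chat")
response = httpx.Response(status_code=429, request=request)

err = BaseLLMException(
    status_code=429,
    message="rate limited",
    headers={"retry-after": "5"},
    request=request,
    response=response,
)
print(err.response.status_code, err.request.url)  # 429 https://example.com/v1/chat
```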
diff --git a/litellm/llms/cerebras/chat.py b/litellm/llms/cerebras/chat.py
index 0b885a5996..09e8ffb834 100644
--- a/litellm/llms/cerebras/chat.py
+++ b/litellm/llms/cerebras/chat.py
@@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
import types
from typing import Optional, Union
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
-class CerebrasConfig:
+
+class CerebrasConfig(OpenAIGPTConfig):
"""
Reference: https://inference-docs.cerebras.ai/api-reference/chat-completions
@@ -18,9 +20,7 @@ class CerebrasConfig:
max_tokens: Optional[int] = None
response_format: Optional[dict] = None
seed: Optional[int] = None
- stop: Optional[str] = None
stream: Optional[bool] = None
- temperature: Optional[float] = None
top_p: Optional[int] = None
tool_choice: Optional[str] = None
tools: Optional[list] = None
@@ -46,21 +46,7 @@ class CerebrasConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
"""
@@ -83,7 +69,11 @@ class CerebrasConfig:
]
def map_openai_params(
- self, model: str, non_default_params: dict, optional_params: dict
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():
diff --git a/litellm/llms/clarifai/chat/transformation.py b/litellm/llms/clarifai/chat/transformation.py
index ae2705d025..9eb4803180 100644
--- a/litellm/llms/clarifai/chat/transformation.py
+++ b/litellm/llms/clarifai/chat/transformation.py
@@ -148,6 +148,7 @@ class ClarifaiConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
diff --git a/litellm/llms/cloudflare/chat/transformation.py b/litellm/llms/cloudflare/chat/transformation.py
index 17d97503b4..4906f7b44e 100644
--- a/litellm/llms/cloudflare/chat/transformation.py
+++ b/litellm/llms/cloudflare/chat/transformation.py
@@ -49,6 +49,10 @@ class CloudflareChatConfig(BaseConfig):
if key != "self" and value is not None:
setattr(self.__class__, key, value)
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
def validate_environment(
self,
headers: dict,
@@ -120,6 +124,7 @@ class CloudflareChatConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py
index b28f37e6f0..204137f793 100644
--- a/litellm/llms/cohere/chat/transformation.py
+++ b/litellm/llms/cohere/chat/transformation.py
@@ -216,7 +216,8 @@ class CohereChatConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
- encoding: str,
+ litellm_params: dict,
+ encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
diff --git a/litellm/llms/cohere/completion/transformation.py b/litellm/llms/cohere/completion/transformation.py
index 9414a88e58..b94d6d24e6 100644
--- a/litellm/llms/cohere/completion/transformation.py
+++ b/litellm/llms/cohere/completion/transformation.py
@@ -217,7 +217,8 @@ class CohereTextConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
- encoding: str,
+ litellm_params: dict,
+ encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index f4d20f8fb9..d08bc794fc 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -211,7 +211,6 @@ class AsyncHTTPHandler:
headers=headers,
)
except httpx.HTTPStatusError as e:
-
if stream is True:
setattr(e, "message", await e.response.aread())
setattr(e, "text", await e.response.aread())
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index de42def31a..e3114a5221 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -51,7 +51,8 @@ class BaseLLMHTTPHandler:
logging_obj: LiteLLMLoggingObj,
messages: list,
optional_params: dict,
- encoding: str,
+ litellm_params: dict,
+ encoding: Any,
api_key: Optional[str] = None,
):
async_httpx_client = get_async_httpx_client(
@@ -75,6 +76,7 @@ class BaseLLMHTTPHandler:
request_data=data,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
encoding=encoding,
)
@@ -163,6 +165,7 @@ class BaseLLMHTTPHandler:
api_key=api_key,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
encoding=encoding,
)
@@ -211,6 +214,7 @@ class BaseLLMHTTPHandler:
request_data=data,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
encoding=encoding,
)
diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py
index 05470f14c8..584b413373 100644
--- a/litellm/llms/databricks/chat/transformation.py
+++ b/litellm/llms/databricks/chat/transformation.py
@@ -10,14 +10,14 @@ from pydantic import BaseModel
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField
-from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ...openai_like.chat.transformation import OpenAILikeChatConfig
from ...prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
strip_name_from_messages,
)
-class DatabricksConfig(OpenAIGPTConfig):
+class DatabricksConfig(OpenAILikeChatConfig):
"""
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
"""
@@ -85,30 +85,6 @@ class DatabricksConfig(OpenAIGPTConfig):
return False
- def map_openai_params(
- self,
- non_default_params: dict,
- optional_params: dict,
- model: str,
- drop_params: bool,
- ):
- for param, value in non_default_params.items():
- if param == "max_tokens" or param == "max_completion_tokens":
- optional_params["max_tokens"] = value
- if param == "n":
- optional_params["n"] = value
- if param == "stream" and value is True:
- optional_params["stream"] = value
- if param == "temperature":
- optional_params["temperature"] = value
- if param == "top_p":
- optional_params["top_p"] = value
- if param == "stop":
- optional_params["stop"] = value
- if param == "response_format":
- optional_params["response_format"] = value
- return optional_params
-
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
diff --git a/litellm/llms/deepinfra/chat/transformation.py b/litellm/llms/deepinfra/chat/transformation.py
new file mode 100644
index 0000000000..0137f409b3
--- /dev/null
+++ b/litellm/llms/deepinfra/chat/transformation.py
@@ -0,0 +1,120 @@
+import types
+from typing import Optional, Tuple, Union
+
+import litellm
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.secret_managers.main import get_secret_str
+
+
+class DeepInfraConfig(OpenAIGPTConfig):
+ """
+ Reference: https://deepinfra.com/docs/advanced/openai_api
+
+ The class `DeepInfraConfig` provides configuration for DeepInfra's Chat Completions API interface. Below are the parameters:
+ """
+
+ frequency_penalty: Optional[int] = None
+ function_call: Optional[Union[str, dict]] = None
+ functions: Optional[list] = None
+ logit_bias: Optional[dict] = None
+ max_tokens: Optional[int] = None
+ n: Optional[int] = None
+ presence_penalty: Optional[int] = None
+ stop: Optional[Union[str, list]] = None
+ temperature: Optional[int] = None
+ top_p: Optional[int] = None
+ response_format: Optional[dict] = None
+ tools: Optional[list] = None
+ tool_choice: Optional[Union[str, dict]] = None
+
+ def __init__(
+ self,
+ frequency_penalty: Optional[int] = None,
+ function_call: Optional[Union[str, dict]] = None,
+ functions: Optional[list] = None,
+ logit_bias: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[int] = None,
+ stop: Optional[Union[str, list]] = None,
+ temperature: Optional[int] = None,
+ top_p: Optional[int] = None,
+ response_format: Optional[dict] = None,
+ tools: Optional[list] = None,
+ tool_choice: Optional[Union[str, dict]] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
+ def get_supported_openai_params(self, model: str):
+ return [
+ "stream",
+ "frequency_penalty",
+ "function_call",
+ "functions",
+ "logit_bias",
+ "max_tokens",
+ "max_completion_tokens",
+ "n",
+ "presence_penalty",
+ "stop",
+ "temperature",
+ "top_p",
+ "response_format",
+ "tools",
+ "tool_choice",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ supported_openai_params = self.get_supported_openai_params(model=model)
+ for param, value in non_default_params.items():
+ if (
+ param == "temperature"
+ and value == 0
+ and model == "mistralai/Mistral-7B-Instruct-v0.1"
+ ): # this model does not support temperature == 0
+ value = 0.0001 # close to 0
+ if param == "tool_choice":
+ if (
+ value != "auto" and value != "none"
+ ): # https://deepinfra.com/docs/advanced/function_calling
+ ## UNSUPPORTED TOOL CHOICE VALUE
+ if litellm.drop_params is True or drop_params is True:
+ value = None
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
+ value
+ ),
+ status_code=400,
+ )
+ elif param == "max_completion_tokens":
+ optional_params["max_tokens"] = value
+ elif param in supported_openai_params:
+ if value is not None:
+ optional_params[param] = value
+ return optional_params
+
+ def _get_openai_compatible_provider_info(
+ self, api_base: Optional[str], api_key: Optional[str]
+ ) -> Tuple[Optional[str], Optional[str]]:
+ # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepinfra.com/v1/openai
+ api_base = (
+ api_base
+ or get_secret_str("DEEPINFRA_API_BASE")
+ or "https://api.deepinfra.com/v1/openai"
+ )
+ dynamic_api_key = api_key or get_secret_str("DEEPINFRA_API_KEY")
+ return api_base, dynamic_api_key
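`DeepInfraConfig` moves out of `openai.py` into its own module and now extends `OpenAIGPTConfig`. A hedged usage sketch of the relocated parameter mapping (the model name is illustrative):

```python
# Hedged sketch: max_completion_tokens is renamed to max_tokens, and other
# supported params are copied through as-is.
from litellm.llms.deepinfra.chat.transformation import DeepInfraConfig

mapped = DeepInfraConfig().map_openai_params(
    non_default_params={"max_completion_tokens": 256, "temperature": 0.7},
    optional_params={},
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model name
    drop_params=False,
)
print(mapped)  # {'max_tokens': 256, 'temperature': 0.7}
```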
diff --git a/litellm/llms/palm.py b/litellm/llms/deprecated_providers/palm.py
similarity index 100%
rename from litellm/llms/palm.py
rename to litellm/llms/deprecated_providers/palm.py
diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py
deleted file mode 100644
index 3b05b70dcc..0000000000
--- a/litellm/llms/gemini.py
+++ /dev/null
@@ -1,421 +0,0 @@
-# ####################################
-# ######### DEPRECATED FILE ##########
-# ####################################
-# # logic moved to `vertex_httpx.py` #
-
-import copy
-import time
-import traceback
-import types
-from typing import Callable, Optional
-
-import httpx
-from packaging.version import Version
-
-import litellm
-from litellm import verbose_logger
-from litellm.utils import Choices, Message, ModelResponse, Usage
-
-from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
-
-
-class GeminiError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- self.request = httpx.Request(
- method="POST",
- url="https://developers.generativeai.google/api/python/google/generativeai/chat",
- )
- self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
-class GeminiConfig:
- """
- Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
-
- The class `GeminiConfig` provides configuration for the Gemini's API interface. Here are the parameters:
-
- - `candidate_count` (int): Number of generated responses to return.
-
- - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
-
- - `max_output_tokens` (int): The maximum number of tokens to include in a candidate. If unset, this will default to output_token_limit specified in the model's specification.
-
- - `temperature` (float): Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature attribute of the Model returned the genai.get_model function. Values can range from [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied and creative, while a value closer to 0.0 will typically result in more straightforward responses from the model.
-
- - `top_p` (float): Optional. The maximum cumulative probability of tokens to consider when sampling.
-
- - `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
- """
-
- candidate_count: Optional[int] = None
- stop_sequences: Optional[list] = None
- max_output_tokens: Optional[int] = None
- temperature: Optional[float] = None
- top_p: Optional[float] = None
- top_k: Optional[int] = None
-
- def __init__(
- self,
- candidate_count: Optional[int] = None,
- stop_sequences: Optional[list] = None,
- max_output_tokens: Optional[int] = None,
- temperature: Optional[float] = None,
- top_p: Optional[float] = None,
- top_k: Optional[int] = None,
- ) -> None:
- locals_ = locals()
- for key, value in locals_.items():
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
-
-
-# class TextStreamer:
-# """
-# A class designed to return an async stream from AsyncGenerateContentResponse object.
-# """
-
-# def __init__(self, response):
-# self.response = response
-# self._aiter = self.response.__aiter__()
-
-# async def __aiter__(self):
-# while True:
-# try:
-# # This will manually advance the async iterator.
-# # In the case the next object doesn't exists, __anext__() will simply raise a StopAsyncIteration exception
-# next_object = await self._aiter.__anext__()
-# yield next_object
-# except StopAsyncIteration:
-# # After getting all items from the async iterator, stop iterating
-# break
-
-
-# def supports_system_instruction():
-# import google.generativeai as genai
-
-# gemini_pkg_version = Version(genai.__version__)
-# return gemini_pkg_version >= Version("0.5.0")
-
-
-# def completion(
-# model: str,
-# messages: list,
-# model_response: ModelResponse,
-# print_verbose: Callable,
-# api_key,
-# encoding,
-# logging_obj,
-# custom_prompt_dict: dict,
-# acompletion: bool = False,
-# optional_params=None,
-# litellm_params=None,
-# logger_fn=None,
-# ):
-# try:
-# import google.generativeai as genai # type: ignore
-# except Exception:
-# raise Exception(
-# "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
-# )
-# genai.configure(api_key=api_key)
-# system_prompt = ""
-# if model in custom_prompt_dict:
-# # check if the model has a registered custom prompt
-# model_prompt_details = custom_prompt_dict[model]
-# prompt = custom_prompt(
-# role_dict=model_prompt_details["roles"],
-# initial_prompt_value=model_prompt_details["initial_prompt_value"],
-# final_prompt_value=model_prompt_details["final_prompt_value"],
-# messages=messages,
-# )
-# else:
-# system_prompt, messages = get_system_prompt(messages=messages)
-# prompt = prompt_factory(
-# model=model, messages=messages, custom_llm_provider="gemini"
-# )
-
-# ## Load Config
-# inference_params = copy.deepcopy(optional_params)
-# stream = inference_params.pop("stream", None)
-
-# # Handle safety settings
-# safety_settings_param = inference_params.pop("safety_settings", None)
-# safety_settings = None
-# if safety_settings_param:
-# safety_settings = [
-# genai.types.SafetySettingDict(x) for x in safety_settings_param
-# ]
-
-# config = litellm.GeminiConfig.get_config()
-# for k, v in config.items():
-# if (
-# k not in inference_params
-# ): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
-# inference_params[k] = v
-
-# ## LOGGING
-# logging_obj.pre_call(
-# input=prompt,
-# api_key="",
-# additional_args={
-# "complete_input_dict": {
-# "inference_params": inference_params,
-# "system_prompt": system_prompt,
-# }
-# },
-# )
-# ## COMPLETION CALL
-# try:
-# _params = {"model_name": "models/{}".format(model)}
-# _system_instruction = supports_system_instruction()
-# if _system_instruction and len(system_prompt) > 0:
-# _params["system_instruction"] = system_prompt
-# _model = genai.GenerativeModel(**_params)
-# if stream is True:
-# if acompletion is True:
-
-# async def async_streaming():
-# try:
-# response = await _model.generate_content_async(
-# contents=prompt,
-# generation_config=genai.types.GenerationConfig(
-# **inference_params
-# ),
-# safety_settings=safety_settings,
-# stream=True,
-# )
-
-# response = litellm.CustomStreamWrapper(
-# TextStreamer(response),
-# model,
-# custom_llm_provider="gemini",
-# logging_obj=logging_obj,
-# )
-# return response
-# except Exception as e:
-# raise GeminiError(status_code=500, message=str(e))
-
-# return async_streaming()
-# response = _model.generate_content(
-# contents=prompt,
-# generation_config=genai.types.GenerationConfig(**inference_params),
-# safety_settings=safety_settings,
-# stream=True,
-# )
-# return response
-# elif acompletion == True:
-# return async_completion(
-# _model=_model,
-# model=model,
-# prompt=prompt,
-# inference_params=inference_params,
-# safety_settings=safety_settings,
-# logging_obj=logging_obj,
-# print_verbose=print_verbose,
-# model_response=model_response,
-# messages=messages,
-# encoding=encoding,
-# )
-# else:
-# params = {
-# "contents": prompt,
-# "generation_config": genai.types.GenerationConfig(**inference_params),
-# "safety_settings": safety_settings,
-# }
-# response = _model.generate_content(**params)
-# except Exception as e:
-# raise GeminiError(
-# message=str(e),
-# status_code=500,
-# )
-
-# ## LOGGING
-# logging_obj.post_call(
-# input=prompt,
-# api_key="",
-# original_response=response,
-# additional_args={"complete_input_dict": {}},
-# )
-# print_verbose(f"raw model_response: {response}")
-# ## RESPONSE OBJECT
-# completion_response = response
-# try:
-# choices_list = []
-# for idx, item in enumerate(completion_response.candidates):
-# if len(item.content.parts) > 0:
-# message_obj = Message(content=item.content.parts[0].text)
-# else:
-# message_obj = Message(content=None)
-# choice_obj = Choices(index=idx, message=message_obj)
-# choices_list.append(choice_obj)
-# model_response.choices = choices_list
-# except Exception as e:
-# verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
-# raise GeminiError(
-# message=traceback.format_exc(), status_code=response.status_code
-# )
-
-# try:
-# completion_response = model_response["choices"][0]["message"].get("content")
-# if completion_response is None:
-# raise Exception
-# except Exception:
-# original_response = f"response: {response}"
-# if hasattr(response, "candidates"):
-# original_response = f"response: {response.candidates}"
-# if "SAFETY" in original_response:
-# original_response += (
-# "\nThe candidate content was flagged for safety reasons."
-# )
-# elif "RECITATION" in original_response:
-# original_response += (
-# "\nThe candidate content was flagged for recitation reasons."
-# )
-# raise GeminiError(
-# status_code=400,
-# message=f"No response received. Original response - {original_response}",
-# )
-
-# ## CALCULATING USAGE
-# prompt_str = ""
-# for m in messages:
-# if isinstance(m["content"], str):
-# prompt_str += m["content"]
-# elif isinstance(m["content"], list):
-# for content in m["content"]:
-# if content["type"] == "text":
-# prompt_str += content["text"]
-# prompt_tokens = len(encoding.encode(prompt_str))
-# completion_tokens = len(
-# encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-# )
-
-# model_response.created = int(time.time())
-# model_response.model = "gemini/" + model
-# usage = Usage(
-# prompt_tokens=prompt_tokens,
-# completion_tokens=completion_tokens,
-# total_tokens=prompt_tokens + completion_tokens,
-# )
-# setattr(model_response, "usage", usage)
-# return model_response
-
-
-# async def async_completion(
-# _model,
-# model,
-# prompt,
-# inference_params,
-# safety_settings,
-# logging_obj,
-# print_verbose,
-# model_response,
-# messages,
-# encoding,
-# ):
-# import google.generativeai as genai # type: ignore
-
-# response = await _model.generate_content_async(
-# contents=prompt,
-# generation_config=genai.types.GenerationConfig(**inference_params),
-# safety_settings=safety_settings,
-# )
-
-# ## LOGGING
-# logging_obj.post_call(
-# input=prompt,
-# api_key="",
-# original_response=response,
-# additional_args={"complete_input_dict": {}},
-# )
-# print_verbose(f"raw model_response: {response}")
-# ## RESPONSE OBJECT
-# completion_response = response
-# try:
-# choices_list = []
-# for idx, item in enumerate(completion_response.candidates):
-# if len(item.content.parts) > 0:
-# message_obj = Message(content=item.content.parts[0].text)
-# else:
-# message_obj = Message(content=None)
-# choice_obj = Choices(index=idx, message=message_obj)
-# choices_list.append(choice_obj)
-# model_response["choices"] = choices_list
-# except Exception as e:
-# verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
-# raise GeminiError(
-# message=traceback.format_exc(), status_code=response.status_code
-# )
-
-# try:
-# completion_response = model_response["choices"][0]["message"].get("content")
-# if completion_response is None:
-# raise Exception
-# except Exception:
-# original_response = f"response: {response}"
-# if hasattr(response, "candidates"):
-# original_response = f"response: {response.candidates}"
-# if "SAFETY" in original_response:
-# original_response += (
-# "\nThe candidate content was flagged for safety reasons."
-# )
-# elif "RECITATION" in original_response:
-# original_response += (
-# "\nThe candidate content was flagged for recitation reasons."
-# )
-# raise GeminiError(
-# status_code=400,
-# message=f"No response received. Original response - {original_response}",
-# )
-
-# ## CALCULATING USAGE
-# prompt_str = ""
-# for m in messages:
-# if isinstance(m["content"], str):
-# prompt_str += m["content"]
-# elif isinstance(m["content"], list):
-# for content in m["content"]:
-# if content["type"] == "text":
-# prompt_str += content["text"]
-# prompt_tokens = len(encoding.encode(prompt_str))
-# completion_tokens = len(
-# encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-# )
-
-# model_response["created"] = int(time.time())
-# model_response["model"] = "gemini/" + model
-# usage = Usage(
-# prompt_tokens=prompt_tokens,
-# completion_tokens=completion_tokens,
-# total_tokens=prompt_tokens + completion_tokens,
-# )
-# model_response.usage = usage
-# return model_response
-
-
-# def embedding():
-# # logic for parsing in - calling - parsing out model embedding calls
-# pass
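The deleted module above was already fully commented out, and its header notes that the logic moved to `vertex_httpx.py`. As a minimal sketch (the model name and env var are illustrative, not introduced by this diff), Gemini calls continue to go through the standard `completion()` entrypoint:

```python
# Sketch only: gemini requests are no longer served by llms/gemini.py.
import os
from litellm import completion

os.environ["GEMINI_API_KEY"] = ""  # assumed env var for the Google AI Studio key

response = completion(
    model="gemini/gemini-1.5-flash",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello from LiteLLM"}],
)
print(response.choices[0].message.content)
```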
diff --git a/litellm/llms/huggingface/chat/handler.py b/litellm/llms/huggingface/chat/handler.py
new file mode 100644
index 0000000000..608bfc6ed8
--- /dev/null
+++ b/litellm/llms/huggingface/chat/handler.py
@@ -0,0 +1,750 @@
+## Uses the huggingface text generation inference API
+import copy
+import enum
+import json
+import os
+import time
+import types
+from enum import Enum
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+ get_args,
+)
+
+import httpx
+import requests
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.huggingface.chat.transformation import (
+ HuggingfaceChatConfig as HuggingfaceConfig,
+)
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.completion import ChatCompletionMessageToolCallParam
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import Logprobs as TextCompletionLogprobs
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks
+
+hf_chat_config = HuggingfaceConfig()
+
+
+hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
+ "sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
+]
+
+
+def get_hf_task_embedding_for_model(
+ model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+ if task_type is not None:
+ if task_type in get_args(hf_tasks_embeddings):
+ return task_type
+ else:
+ raise Exception(
+ "Invalid task_type={}. Expected one of={}".format(
+ task_type, hf_tasks_embeddings
+ )
+ )
+ http_client = HTTPHandler(concurrent_limit=1)
+
+ model_info = http_client.get(url=api_base)
+
+ model_info_dict = model_info.json()
+
+ pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+
+ return pipeline_tag
+
+
+async def async_get_hf_task_embedding_for_model(
+ model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+ if task_type is not None:
+ if task_type in get_args(hf_tasks_embeddings):
+ return task_type
+ else:
+ raise Exception(
+ "Invalid task_type={}. Expected one of={}".format(
+ task_type, hf_tasks_embeddings
+ )
+ )
+ http_client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.HUGGINGFACE,
+ )
+
+ model_info = await http_client.get(url=api_base)
+
+ model_info_dict = model_info.json()
+
+ pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+
+ return pipeline_tag
+
+
+async def make_call(
+ client: Optional[AsyncHTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ json_mode: bool,
+) -> Tuple[Any, httpx.Headers]:
+ if client is None:
+ client = litellm.module_level_aclient
+
+ try:
+ response = await client.post(
+ api_base, headers=headers, data=data, stream=True, timeout=timeout
+ )
+ except httpx.HTTPStatusError as e:
+ error_headers = getattr(e, "headers", None)
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise HuggingfaceError(
+ status_code=e.response.status_code,
+ message=str(await e.response.aread()),
+ headers=cast(dict, error_headers) if error_headers else None,
+ )
+ except Exception as e:
+ for exception in litellm.LITELLM_EXCEPTION_TYPES:
+ if isinstance(e, exception):
+ raise e
+ raise HuggingfaceError(status_code=500, message=str(e))
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=response, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return response.aiter_lines(), response.headers
+
+
+class Huggingface(BaseLLM):
+ _client_session: Optional[httpx.Client] = None
+ _aclient_session: Optional[httpx.AsyncClient] = None
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def completion( # noqa: PLR0915
+ self,
+ model: str,
+ messages: list,
+ api_base: Optional[str],
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ timeout: float,
+ encoding,
+ api_key,
+ logging_obj,
+ optional_params: dict,
+ litellm_params: dict,
+ custom_prompt_dict={},
+ acompletion: bool = False,
+ logger_fn=None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ headers: dict = {},
+ ):
+ super().completion()
+ exception_mapping_worked = False
+ try:
+ task, model = hf_chat_config.get_hf_task_for_model(model)
+ litellm_params["task"] = task
+ headers = hf_chat_config.validate_environment(
+ api_key=api_key,
+ headers=headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+ completion_url = hf_chat_config.get_api_base(api_base=api_base, model=model)
+ data = hf_chat_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=data,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "headers": headers,
+ "api_base": completion_url,
+ "acompletion": acompletion,
+ },
+ )
+ ## COMPLETION CALL
+
+ if acompletion is True:
+ ### ASYNC STREAMING
+ if optional_params.get("stream", False):
+ return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages) # type: ignore
+ else:
+ ### ASYNC COMPLETION
+ return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, model=model, optional_params=optional_params, timeout=timeout, litellm_params=litellm_params) # type: ignore
+ if client is None or not isinstance(client, HTTPHandler):
+ client = HTTPHandler()
+ ### SYNC STREAMING
+ if "stream" in optional_params and optional_params["stream"] is True:
+ response = client.post(
+ url=completion_url,
+ headers=headers,
+ data=json.dumps(data),
+ stream=optional_params["stream"],
+ )
+ return response.iter_lines()
+ ### SYNC COMPLETION
+ else:
+ response = client.post(
+ url=completion_url,
+ headers=headers,
+ data=json.dumps(data),
+ )
+
+ return hf_chat_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ encoding=encoding,
+ json_mode=None,
+ litellm_params=litellm_params,
+ )
+ except httpx.HTTPStatusError as e:
+ raise HuggingfaceError(
+ status_code=e.response.status_code,
+ message=e.response.text,
+ headers=e.response.headers,
+ )
+ except HuggingfaceError as e:
+ exception_mapping_worked = True
+ raise e
+ except Exception as e:
+ if exception_mapping_worked:
+ raise e
+ else:
+ import traceback
+
+ raise HuggingfaceError(status_code=500, message=traceback.format_exc())
+
+ async def acompletion(
+ self,
+ api_base: str,
+ data: dict,
+ headers: dict,
+ model_response: ModelResponse,
+ encoding: Any,
+ model: str,
+ optional_params: dict,
+ litellm_params: dict,
+ timeout: float,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: str,
+ messages: List[AllMessageValues],
+ ):
+ response: Optional[httpx.Response] = None
+ try:
+ http_client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.HUGGINGFACE
+ )
+ ### ASYNC COMPLETION
+ http_response = await http_client.post(
+ url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
+ )
+
+ response = http_response
+
+ return hf_chat_config.transform_response(
+ model=model,
+ raw_response=http_response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ encoding=encoding,
+ json_mode=None,
+ litellm_params=litellm_params,
+ )
+ except Exception as e:
+ if isinstance(e, httpx.TimeoutException):
+ raise HuggingfaceError(status_code=500, message="Request Timeout Error")
+ elif isinstance(e, HuggingfaceError):
+ raise e
+ elif response is not None and hasattr(response, "text"):
+ raise HuggingfaceError(
+ status_code=500,
+ message=f"{str(e)}\n\nOriginal Response: {response.text}",
+ headers=response.headers,
+ )
+ else:
+ raise HuggingfaceError(status_code=500, message=f"{str(e)}")
+
+ async def async_streaming(
+ self,
+ logging_obj,
+ api_base: str,
+ data: dict,
+ headers: dict,
+ model_response: ModelResponse,
+ messages: List[AllMessageValues],
+ model: str,
+ timeout: float,
+ client: Optional[AsyncHTTPHandler] = None,
+ ):
+ completion_stream, _ = await make_call(
+ client=client,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ timeout=timeout,
+ json_mode=False,
+ )
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="huggingface",
+ logging_obj=logging_obj,
+ )
+ return streamwrapper
+
+ def _transform_input_on_pipeline_tag(
+ self, input: List, pipeline_tag: Optional[str]
+ ) -> dict:
+ if pipeline_tag is None:
+ return {"inputs": input}
+ if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
+ if len(input) < 2:
+ raise HuggingfaceError(
+ status_code=400,
+ message="sentence-similarity requires 2+ sentences",
+ )
+ return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
+ elif pipeline_tag == "rerank":
+ if len(input) < 2:
+ raise HuggingfaceError(
+ status_code=400,
+ message="reranker requires 2+ sentences",
+ )
+ return {"inputs": {"query": input[0], "texts": input[1:]}}
+ return {"inputs": input} # default to feature-extraction pipeline tag
+
+ async def _async_transform_input(
+ self,
+ model: str,
+ task_type: Optional[str],
+ embed_url: str,
+ input: List,
+ optional_params: dict,
+ ) -> dict:
+ hf_task = await async_get_hf_task_embedding_for_model(
+ model=model, task_type=task_type, api_base=embed_url
+ )
+
+ data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
+
+ if len(optional_params.keys()) > 0:
+ data["options"] = optional_params
+
+ return data
+
+ def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
+ special_options_keys = HuggingfaceConfig().get_special_options_params()
+ special_parameters_keys = [
+ "min_length",
+ "max_length",
+ "top_k",
+ "top_p",
+ "temperature",
+ "repetition_penalty",
+ "max_time",
+ ]
+
+ for k, v in optional_params.items():
+ if k in special_options_keys:
+ data.setdefault("options", {})
+ data["options"][k] = v
+ elif k in special_parameters_keys:
+ data.setdefault("parameters", {})
+ data["parameters"][k] = v
+ else:
+ data[k] = v
+
+ return data
+
+ def _transform_input(
+ self,
+ input: List,
+ model: str,
+ call_type: Literal["sync", "async"],
+ optional_params: dict,
+ embed_url: str,
+ ) -> dict:
+ data: Dict = {}
+ ## TRANSFORMATION ##
+ if "sentence-transformers" in model:
+ if len(input) == 0:
+ raise HuggingfaceError(
+ status_code=400,
+ message="sentence transformers requires 2+ sentences",
+ )
+ data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
+ else:
+ data = {"inputs": input}
+
+ task_type = optional_params.pop("input_type", None)
+
+ if call_type == "sync":
+ hf_task = get_hf_task_embedding_for_model(
+ model=model, task_type=task_type, api_base=embed_url
+ )
+ elif call_type == "async":
+ return self._async_transform_input(
+ model=model, task_type=task_type, embed_url=embed_url, input=input
+ ) # type: ignore
+
+ data = self._transform_input_on_pipeline_tag(
+ input=input, pipeline_tag=hf_task
+ )
+
+ if len(optional_params.keys()) > 0:
+ data = self._process_optional_params(
+ data=data, optional_params=optional_params
+ )
+
+ return data
+
+ def _process_embedding_response(
+ self,
+ embeddings: dict,
+ model_response: litellm.EmbeddingResponse,
+ model: str,
+ input: List,
+ encoding: Any,
+ ) -> litellm.EmbeddingResponse:
+ output_data = []
+ if "similarities" in embeddings:
+ for idx, embedding in embeddings["similarities"]:
+ output_data.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding, # flatten list returned from hf
+ }
+ )
+ else:
+ for idx, embedding in enumerate(embeddings):
+ if isinstance(embedding, float):
+ output_data.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding, # flatten list returned from hf
+ }
+ )
+ elif isinstance(embedding, list) and isinstance(embedding[0], float):
+ output_data.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding, # flatten list returned from hf
+ }
+ )
+ else:
+ output_data.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding[0][
+ 0
+ ], # flatten list returned from hf
+ }
+ )
+ model_response.object = "list"
+ model_response.data = output_data
+ model_response.model = model
+ input_tokens = 0
+ for text in input:
+ input_tokens += len(encoding.encode(text))
+
+ setattr(
+ model_response,
+ "usage",
+ litellm.Usage(
+ prompt_tokens=input_tokens,
+ completion_tokens=input_tokens,
+ total_tokens=input_tokens,
+ prompt_tokens_details=None,
+ completion_tokens_details=None,
+ ),
+ )
+ return model_response
+
+ async def aembedding(
+ self,
+ model: str,
+ input: list,
+ model_response: litellm.utils.EmbeddingResponse,
+ timeout: Union[float, httpx.Timeout],
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ api_base: str,
+ api_key: Optional[str],
+ headers: dict,
+ encoding: Callable,
+ client: Optional[AsyncHTTPHandler] = None,
+ ):
+ ## TRANSFORMATION ##
+ data = self._transform_input(
+ input=input,
+ model=model,
+ call_type="sync",
+ optional_params=optional_params,
+ embed_url=api_base,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+ ## COMPLETION CALL
+ if client is None:
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.HUGGINGFACE,
+ )
+
+ response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=response,
+ )
+
+ embeddings = response.json()
+
+ if "error" in embeddings:
+ raise HuggingfaceError(status_code=500, message=embeddings["error"])
+
+ ## PROCESS RESPONSE ##
+ return self._process_embedding_response(
+ embeddings=embeddings,
+ model_response=model_response,
+ model=model,
+ input=input,
+ encoding=encoding,
+ )
+
+ def embedding(
+ self,
+ model: str,
+ input: list,
+ model_response: litellm.EmbeddingResponse,
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ encoding: Callable,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+ aembedding: Optional[bool] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ headers={},
+ ) -> litellm.EmbeddingResponse:
+ super().embedding()
+ headers = hf_chat_config.validate_environment(
+ api_key=api_key,
+ headers=headers,
+ model=model,
+ optional_params=optional_params,
+ messages=[],
+ )
+ # print_verbose(f"{model}, {task}")
+ embed_url = ""
+ if "https" in model:
+ embed_url = model
+ elif api_base:
+ embed_url = api_base
+ elif "HF_API_BASE" in os.environ:
+ embed_url = os.getenv("HF_API_BASE", "")
+ elif "HUGGINGFACE_API_BASE" in os.environ:
+ embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
+ else:
+ embed_url = f"https://api-inference.huggingface.co/models/{model}"
+
+ ## ROUTING ##
+ if aembedding is True:
+ return self.aembedding(
+ input=input,
+ model_response=model_response,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ headers=headers,
+ api_base=embed_url, # type: ignore
+ api_key=api_key,
+ client=client if isinstance(client, AsyncHTTPHandler) else None,
+ model=model,
+ optional_params=optional_params,
+ encoding=encoding,
+ )
+
+ ## TRANSFORMATION ##
+
+ data = self._transform_input(
+ input=input,
+ model=model,
+ call_type="sync",
+ optional_params=optional_params,
+ embed_url=embed_url,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "headers": headers,
+ "api_base": embed_url,
+ },
+ )
+ ## COMPLETION CALL
+ if client is None or not isinstance(client, HTTPHandler):
+ client = HTTPHandler(concurrent_limit=1)
+ response = client.post(embed_url, headers=headers, data=json.dumps(data))
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=response,
+ )
+
+ embeddings = response.json()
+
+ if "error" in embeddings:
+ raise HuggingfaceError(status_code=500, message=embeddings["error"])
+
+ ## PROCESS RESPONSE ##
+ return self._process_embedding_response(
+ embeddings=embeddings,
+ model_response=model_response,
+ model=model,
+ input=input,
+ encoding=encoding,
+ )
+
+ def _transform_logprobs(
+ self, hf_response: Optional[List]
+ ) -> Optional[TextCompletionLogprobs]:
+ """
+ Transform Hugging Face logprobs to OpenAI.Completion() format
+ """
+ if hf_response is None:
+ return None
+
+ # Initialize an empty list for the transformed logprobs
+ _logprob: TextCompletionLogprobs = TextCompletionLogprobs(
+ text_offset=[],
+ token_logprobs=[],
+ tokens=[],
+ top_logprobs=[],
+ )
+
+ # For each Hugging Face response, transform the logprobs
+ for response in hf_response:
+ # Extract the relevant information from the response
+ response_details = response["details"]
+ top_tokens = response_details.get("top_tokens", {})
+
+ for i, token in enumerate(response_details["prefill"]):
+ # Extract the text of the token
+ token_text = token["text"]
+
+ # Extract the logprob of the token
+ token_logprob = token["logprob"]
+
+ # Add the token information to the 'token_info' list
+ _logprob.tokens.append(token_text)
+ _logprob.token_logprobs.append(token_logprob)
+
+ # stub this to work with llm eval harness
+ top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
+ _logprob.top_logprobs.append(top_alt_tokens)
+
+ # For each element in the 'tokens' list, extract the relevant information
+ for i, token in enumerate(response_details["tokens"]):
+ # Extract the text of the token
+ token_text = token["text"]
+
+ # Extract the logprob of the token
+ token_logprob = token["logprob"]
+
+ top_alt_tokens = {}
+ temp_top_logprobs = []
+ if top_tokens != {}:
+ temp_top_logprobs = top_tokens[i]
+
+ # top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
+ for elem in temp_top_logprobs:
+ text = elem["text"]
+ logprob = elem["logprob"]
+ top_alt_tokens[text] = logprob
+
+ # Add the token information to the 'token_info' list
+ _logprob.tokens.append(token_text)
+ _logprob.token_logprobs.append(token_logprob)
+ _logprob.top_logprobs.append(top_alt_tokens)
+
+ # Add the text offset of the token
+ # This is computed as the sum of the lengths of all previous tokens
+ _logprob.text_offset.append(
+ sum(len(t["text"]) for t in response_details["tokens"][:i])
+ )
+
+ return _logprob
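A minimal sketch of the embedding payload shaping done by `_transform_input_on_pipeline_tag` above; the import path follows the new module layout in this diff and the inputs are illustrative.

```python
# Sketch only: payload shaping per pipeline tag, per the handler code above.
from litellm.llms.huggingface.chat.handler import Huggingface

hf = Huggingface()

# sentence-similarity / similarity: first item is the source sentence
print(hf._transform_input_on_pipeline_tag(
    input=["a happy dog", "a joyful puppy", "a sad cat"],
    pipeline_tag="sentence-similarity",
))
# -> {'inputs': {'source_sentence': 'a happy dog', 'sentences': ['a joyful puppy', 'a sad cat']}}

# rerank: first item is the query, the rest are candidate texts
print(hf._transform_input_on_pipeline_tag(
    input=["what is litellm?", "litellm is an llm gateway", "unrelated text"],
    pipeline_tag="rerank",
))
# -> {'inputs': {'query': 'what is litellm?', 'texts': ['litellm is an llm gateway', 'unrelated text']}}

# anything else (including None) falls back to the plain feature-extraction payload
print(hf._transform_input_on_pipeline_tag(input=["hello"], pipeline_tag=None))
# -> {'inputs': ['hello']}
```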
diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py
new file mode 100644
index 0000000000..8880ec41c3
--- /dev/null
+++ b/litellm/llms/huggingface/chat/transformation.py
@@ -0,0 +1,590 @@
+import json
+import os
+import time
+import types
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.utils import token_counter
+
+from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks, output_parser
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
+
+
+tgi_models_cache = None
+conv_models_cache = None
+
+
+class HuggingfaceChatConfig(BaseConfig):
+ """
+ Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
+ """
+
+ hf_task: Optional[hf_tasks] = (
+ None # litellm-specific param, used to know the api spec to use when calling huggingface api
+ )
+ best_of: Optional[int] = None
+ decoder_input_details: Optional[bool] = None
+ details: Optional[bool] = True # enables returning logprobs + best of
+ max_new_tokens: Optional[int] = None
+ repetition_penalty: Optional[float] = None
+ return_full_text: Optional[bool] = (
+ False # by default don't return the input as part of the output
+ )
+ seed: Optional[int] = None
+ temperature: Optional[float] = None
+ top_k: Optional[int] = None
+ top_n_tokens: Optional[int] = None
+ top_p: Optional[int] = None
+ truncate: Optional[int] = None
+ typical_p: Optional[float] = None
+ watermark: Optional[bool] = None
+
+ def __init__(
+ self,
+ best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ details: Optional[bool] = None,
+ max_new_tokens: Optional[int] = None,
+ repetition_penalty: Optional[float] = None,
+ return_full_text: Optional[bool] = None,
+ seed: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
+ top_p: Optional[int] = None,
+ truncate: Optional[int] = None,
+ typical_p: Optional[float] = None,
+ watermark: Optional[bool] = None,
+ ) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
+ def get_special_options_params(self):
+ return ["use_cache", "wait_for_model"]
+
+ def get_supported_openai_params(self, model: str):
+ return [
+ "stream",
+ "temperature",
+ "max_tokens",
+ "max_completion_tokens",
+ "top_p",
+ "stop",
+ "n",
+ "echo",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: Dict,
+ optional_params: Dict,
+ model: str,
+ drop_params: bool,
+ ) -> Dict:
+ for param, value in non_default_params.items():
+ # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+ if param == "temperature":
+ if value == 0.0 or value == 0:
+ # hugging face exception raised when temp==0
+ # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+ value = 0.01
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "n":
+ optional_params["best_of"] = value
+ optional_params["do_sample"] = (
+ True # Need to sample if you want best of for hf inference endpoints
+ )
+ if param == "stream":
+ optional_params["stream"] = value
+ if param == "stop":
+ optional_params["stop"] = value
+ if param == "max_tokens" or param == "max_completion_tokens":
+ # HF TGI raises the following exception when max_new_tokens==0
+ # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+ if value == 0:
+ value = 1
+ optional_params["max_new_tokens"] = value
+ if param == "echo":
+ # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+ # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+ optional_params["decoder_input_details"] = True
+
+ return optional_params
+
+ def get_hf_api_key(self) -> Optional[str]:
+ return get_secret_str("HUGGINGFACE_API_KEY")
+
+ def read_tgi_conv_models(self):
+ try:
+ global tgi_models_cache, conv_models_cache
+ # Check if the cache is already populated
+ # so we don't keep on reading txt file if there are 1k requests
+ if (tgi_models_cache is not None) and (conv_models_cache is not None):
+ return tgi_models_cache, conv_models_cache
+ # If not, read the file and populate the cache
+ tgi_models = set()
+ script_directory = os.path.dirname(os.path.abspath(__file__))
+ script_directory = os.path.dirname(script_directory)
+ # Construct the file path relative to the script's directory
+ file_path = os.path.join(
+ script_directory,
+ "huggingface_llms_metadata",
+ "hf_text_generation_models.txt",
+ )
+
+ with open(file_path, "r") as file:
+ for line in file:
+ tgi_models.add(line.strip())
+
+ # Cache the set for future use
+ tgi_models_cache = tgi_models
+
+ # If not, read the file and populate the cache
+ file_path = os.path.join(
+ script_directory,
+ "huggingface_llms_metadata",
+ "hf_conversational_models.txt",
+ )
+ conv_models = set()
+ with open(file_path, "r") as file:
+ for line in file:
+ conv_models.add(line.strip())
+ # Cache the set for future use
+ conv_models_cache = conv_models
+ return tgi_models, conv_models
+ except Exception:
+ return set(), set()
+
+ def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
+ # read text file, cast it to set
+ # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
+ if model.split("/")[0] in hf_task_list:
+ split_model = model.split("/", 1)
+ return split_model[0], split_model[1] # type: ignore
+ tgi_models, conversational_models = self.read_tgi_conv_models()
+
+ if model in tgi_models:
+ return "text-generation-inference", model
+ elif model in conversational_models:
+ return "conversational", model
+ elif "roneneldan/TinyStories" in model:
+ return "text-generation", model
+ else:
+ return "text-generation-inference", model # default to tgi
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ task = litellm_params.get("task", None)
+ ## VALIDATE API FORMAT
+ if task is None or not isinstance(task, str) or task not in hf_task_list:
+ raise Exception(
+ "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
+ )
+
+ ## Load Config
+ config = litellm.HuggingfaceConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+ ): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
+ optional_params[k] = v
+
+ ### MAP INPUT PARAMS
+ #### HANDLE SPECIAL PARAMS
+ special_params = self.get_special_options_params()
+ special_params_dict = {}
+ # Create a list of keys to pop after iteration
+ keys_to_pop = []
+
+ for k, v in optional_params.items():
+ if k in special_params:
+ special_params_dict[k] = v
+ keys_to_pop.append(k)
+
+ # Pop the keys from the dictionary after iteration
+ for k in keys_to_pop:
+ optional_params.pop(k)
+ if task == "conversational":
+ inference_params = deepcopy(optional_params)
+ inference_params.pop("details")
+ inference_params.pop("return_full_text")
+ past_user_inputs = []
+ generated_responses = []
+ text = ""
+ for message in messages:
+ if message["role"] == "user":
+ if text != "":
+ past_user_inputs.append(text)
+ text = convert_content_list_to_str(message)
+ elif message["role"] == "assistant" or message["role"] == "system":
+ generated_responses.append(convert_content_list_to_str(message))
+ data = {
+ "inputs": {
+ "text": text,
+ "past_user_inputs": past_user_inputs,
+ "generated_responses": generated_responses,
+ },
+ "parameters": inference_params,
+ }
+
+ elif task == "text-generation-inference":
+ # always send "details" and "return_full_text" as params
+ if model in litellm.custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = litellm.custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", None),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get(
+ "final_prompt_value", ""
+ ),
+ messages=messages,
+ )
+ else:
+ prompt = prompt_factory(model=model, messages=messages)
+ data = {
+ "inputs": prompt, # type: ignore
+ "parameters": optional_params,
+ "stream": ( # type: ignore
+ True
+ if "stream" in optional_params
+ and isinstance(optional_params["stream"], bool)
+ and optional_params["stream"] is True # type: ignore
+ else False
+ ),
+ }
+ else:
+ # Non TGI and Conversational llms
+ # We need this branch, it removes 'details' and 'return_full_text' from params
+ if model in litellm.custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = litellm.custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", {}),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get(
+ "final_prompt_value", ""
+ ),
+ bos_token=model_prompt_details.get("bos_token", ""),
+ eos_token=model_prompt_details.get("eos_token", ""),
+ messages=messages,
+ )
+ else:
+ prompt = prompt_factory(model=model, messages=messages)
+ inference_params = deepcopy(optional_params)
+ inference_params.pop("details")
+ inference_params.pop("return_full_text")
+ data = {
+ "inputs": prompt, # type: ignore
+ }
+ if task == "text-generation-inference":
+ data["parameters"] = inference_params
+ data["stream"] = ( # type: ignore
+ True # type: ignore
+ if "stream" in optional_params and optional_params["stream"] is True
+ else False
+ )
+
+ ### RE-ADD SPECIAL PARAMS
+ if len(special_params_dict.keys()) > 0:
+ data.update({"options": special_params_dict})
+
+ return data
+
+ def get_api_base(self, api_base: Optional[str], model: str) -> str:
+ """
+ Get the API base for the Huggingface API.
+
+ Do not add the chat/embedding/rerank extension here. Let the handler do this.
+ """
+ if "https" in model:
+ completion_url = model
+ elif api_base is not None:
+ completion_url = api_base
+ elif "HF_API_BASE" in os.environ:
+ completion_url = os.getenv("HF_API_BASE", "")
+ elif "HUGGINGFACE_API_BASE" in os.environ:
+ completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
+ else:
+ completion_url = f"https://api-inference.huggingface.co/models/{model}"
+
+ return completion_url
+
+ def validate_environment(
+ self,
+ headers: Dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ api_key: Optional[str] = None,
+ ) -> Dict:
+ default_headers = {
+ "content-type": "application/json",
+ }
+ if api_key is not None:
+ default_headers["Authorization"] = (
+ f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
+ )
+
+ headers = {**headers, **default_headers}
+ return headers
+
+ def _transform_messages(
+ self,
+ messages: List[AllMessageValues],
+ ) -> List[AllMessageValues]:
+ return messages
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return HuggingfaceError(
+ status_code=status_code, message=error_message, headers=headers
+ )
+
+ def _convert_streamed_response_to_complete_response(
+ self,
+ response: httpx.Response,
+ logging_obj: LoggingClass,
+ model: str,
+ data: dict,
+ api_key: Optional[str] = None,
+ ) -> List[Dict[str, Any]]:
+ streamed_response = CustomStreamWrapper(
+ completion_stream=response.iter_lines(),
+ model=model,
+ custom_llm_provider="huggingface",
+ logging_obj=logging_obj,
+ )
+ content = ""
+ for chunk in streamed_response:
+ content += chunk["choices"][0]["delta"]["content"]
+ completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
+ ## LOGGING
+ logging_obj.post_call(
+ input=data,
+ api_key=api_key,
+ original_response=completion_response,
+ additional_args={"complete_input_dict": data},
+ )
+ return completion_response
+
+ def convert_to_model_response_object( # noqa: PLR0915
+ self,
+ completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
+ model_response: litellm.ModelResponse,
+ task: Optional[hf_tasks],
+ optional_params: dict,
+ encoding: Any,
+ messages: List[AllMessageValues],
+ model: str,
+ ):
+ if task is None:
+ task = "text-generation-inference" # default to tgi
+
+ if task == "conversational":
+ if len(completion_response["generated_text"]) > 0: # type: ignore
+ model_response.choices[0].message.content = completion_response[ # type: ignore
+ "generated_text"
+ ]
+ elif task == "text-generation-inference":
+ if (
+ not isinstance(completion_response, list)
+ or not isinstance(completion_response[0], dict)
+ or "generated_text" not in completion_response[0]
+ ):
+ raise HuggingfaceError(
+ status_code=422,
+ message=f"response is not in expected format - {completion_response}",
+ headers=None,
+ )
+
+ if len(completion_response[0]["generated_text"]) > 0:
+ model_response.choices[0].message.content = output_parser( # type: ignore
+ completion_response[0]["generated_text"]
+ )
+ ## GETTING LOGPROBS + FINISH REASON
+ if (
+ "details" in completion_response[0]
+ and "tokens" in completion_response[0]["details"]
+ ):
+ model_response.choices[0].finish_reason = completion_response[0][
+ "details"
+ ]["finish_reason"]
+ sum_logprob = 0
+ for token in completion_response[0]["details"]["tokens"]:
+ if token["logprob"] is not None:
+ sum_logprob += token["logprob"]
+ setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
+ if "best_of" in optional_params and optional_params["best_of"] > 1:
+ if (
+ "details" in completion_response[0]
+ and "best_of_sequences" in completion_response[0]["details"]
+ ):
+ choices_list = []
+ for idx, item in enumerate(
+ completion_response[0]["details"]["best_of_sequences"]
+ ):
+ sum_logprob = 0
+ for token in item["tokens"]:
+ if token["logprob"] is not None:
+ sum_logprob += token["logprob"]
+ if len(item["generated_text"]) > 0:
+ message_obj = Message(
+ content=output_parser(item["generated_text"]),
+ logprobs=sum_logprob,
+ )
+ else:
+ message_obj = Message(content=None)
+ choice_obj = Choices(
+ finish_reason=item["finish_reason"],
+ index=idx + 1,
+ message=message_obj,
+ )
+ choices_list.append(choice_obj)
+ model_response.choices.extend(choices_list)
+ elif task == "text-classification":
+ model_response.choices[0].message.content = json.dumps( # type: ignore
+ completion_response
+ )
+ else:
+ if (
+ isinstance(completion_response, list)
+ and len(completion_response[0]["generated_text"]) > 0
+ ):
+ model_response.choices[0].message.content = output_parser( # type: ignore
+ completion_response[0]["generated_text"]
+ )
+ ## CALCULATING USAGE
+ prompt_tokens = 0
+ try:
+ prompt_tokens = token_counter(model=model, messages=messages)
+ except Exception:
+            # this should remain non-blocking; we should not block a response from returning if calculating usage fails
+ pass
+ output_text = model_response["choices"][0]["message"].get("content", "")
+ if output_text is not None and len(output_text) > 0:
+ completion_tokens = 0
+ try:
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", "")
+ )
+ ) ##[TODO] use the llama2 tokenizer here
+ except Exception:
+                # this should remain non-blocking; we should not block a response from returning if calculating usage fails
+ pass
+ else:
+ completion_tokens = 0
+
+ model_response.created = int(time.time())
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+ model_response._hidden_params["original_response"] = completion_response
+ return model_response
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: Dict,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
+ task = litellm_params.get("task", None)
+ is_streamed = False
+ if (
+ raw_response.__dict__["headers"].get("Content-Type", "")
+ == "text/event-stream"
+ ):
+ is_streamed = True
+
+ # iterate over the complete streamed response, and return the final answer
+ if is_streamed:
+ completion_response = self._convert_streamed_response_to_complete_response(
+ response=raw_response,
+ logging_obj=logging_obj,
+ model=model,
+ data=request_data,
+ api_key=api_key,
+ )
+ else:
+ ## LOGGING
+ logging_obj.post_call(
+ input=request_data,
+ api_key=api_key,
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
+ )
+ ## RESPONSE OBJECT
+ try:
+ completion_response = raw_response.json()
+ if isinstance(completion_response, dict):
+ completion_response = [completion_response]
+ except Exception:
+ raise HuggingfaceError(
+ message=f"Original Response received: {raw_response.text}",
+ status_code=raw_response.status_code,
+ )
+
+ if isinstance(completion_response, dict) and "error" in completion_response:
+ raise HuggingfaceError(
+ message=completion_response["error"], # type: ignore
+ status_code=raw_response.status_code,
+ )
+ return self.convert_to_model_response_object(
+ completion_response=completion_response,
+ model_response=model_response,
+ task=task if task is not None and task in hf_task_list else None,
+ optional_params=optional_params,
+ encoding=encoding,
+ messages=messages,
+ model=model,
+ )
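A minimal sketch of the task routing and OpenAI-parameter mapping defined in `HuggingfaceChatConfig` above; the import path follows this diff and the model names are illustrative.

```python
# Sketch only: HuggingfaceChatConfig behaviour per the transformation code above.
from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig

cfg = HuggingfaceChatConfig()

# An explicit task prefix is split off the model string ...
print(cfg.get_hf_task_for_model("conversational/facebook/blenderbot-400M-distill"))
# -> ('conversational', 'facebook/blenderbot-400M-distill')

# ... otherwise unknown models default to text-generation-inference.
print(cfg.get_hf_task_for_model("my-org/my-custom-endpoint"))
# -> ('text-generation-inference', 'my-org/my-custom-endpoint')

# OpenAI params map to TGI params: temperature=0 becomes 0.01,
# max_tokens=0 becomes max_new_tokens=1, n maps to best_of + do_sample.
print(cfg.map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2},
    optional_params={},
    model="HuggingFaceH4/zephyr-7b-beta",
    drop_params=False,
))
# -> {'temperature': 0.01, 'max_new_tokens': 1, 'best_of': 2, 'do_sample': True}

# Without an api_base arg or HF_API_BASE/HUGGINGFACE_API_BASE env var,
# the public inference API is used.
print(cfg.get_api_base(api_base=None, model="HuggingFaceH4/zephyr-7b-beta"))
# -> https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta
```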
diff --git a/litellm/llms/huggingface/common_utils.py b/litellm/llms/huggingface/common_utils.py
new file mode 100644
index 0000000000..c63a4a0d1d
--- /dev/null
+++ b/litellm/llms/huggingface/common_utils.py
@@ -0,0 +1,45 @@
+from typing import Literal, Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class HuggingfaceError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[dict, httpx.Headers]] = None,
+ ):
+ super().__init__(status_code=status_code, message=message, headers=headers)
+
+
+hf_tasks = Literal[
+ "text-generation-inference",
+ "conversational",
+ "text-classification",
+ "text-generation",
+]
+
+hf_task_list = [
+ "text-generation-inference",
+ "conversational",
+ "text-classification",
+ "text-generation",
+]
+
+
+def output_parser(generated_text: str):
+ """
+ Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
+
+ Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
+ """
+    chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
+ for token in chat_template_tokens:
+ if generated_text.strip().startswith(token):
+ generated_text = generated_text.replace(token, "", 1)
+ if generated_text.endswith(token):
+ generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
+ return generated_text
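A minimal sketch of `output_parser` above, which strips a single leading or trailing chat-template token from generated text; the import path follows this diff.

```python
# Sketch only: output_parser per the helper above.
from litellm.llms.huggingface.common_utils import output_parser

# A leading chat-template token is removed once; the rest of the text is kept.
print(output_parser("<|assistant|> Hello! How can I help?"))
# -> ' Hello! How can I help?'

# A trailing end-of-sequence token is removed once as well.
print(output_parser("Hello! How can I help?</s>"))
# -> 'Hello! How can I help?'
```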
diff --git a/litellm/llms/huggingface_llms_metadata/hf_conversational_models.txt b/litellm/llms/huggingface/huggingface_llms_metadata/hf_conversational_models.txt
similarity index 100%
rename from litellm/llms/huggingface_llms_metadata/hf_conversational_models.txt
rename to litellm/llms/huggingface/huggingface_llms_metadata/hf_conversational_models.txt
diff --git a/litellm/llms/huggingface_llms_metadata/hf_text_generation_models.txt b/litellm/llms/huggingface/huggingface_llms_metadata/hf_text_generation_models.txt
similarity index 100%
rename from litellm/llms/huggingface_llms_metadata/hf_text_generation_models.txt
rename to litellm/llms/huggingface/huggingface_llms_metadata/hf_text_generation_models.txt
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
deleted file mode 100644
index 8b45f1ae7d..0000000000
--- a/litellm/llms/huggingface_restapi.py
+++ /dev/null
@@ -1,1264 +0,0 @@
-## Uses the huggingface text generation inference API
-import copy
-import enum
-import json
-import os
-import time
-import types
-from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, get_args
-
-import httpx
-import requests
-
-import litellm
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.secret_managers.main import get_secret_str
-from litellm.types.completion import ChatCompletionMessageToolCallParam
-from litellm.types.utils import Logprobs as TextCompletionLogprobs
-from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
-
-from .base import BaseLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class HuggingfaceError(Exception):
- def __init__(
- self,
- status_code,
- message,
- request: Optional[httpx.Request] = None,
- response: Optional[httpx.Response] = None,
- ):
- self.status_code = status_code
- self.message = message
- if request is not None:
- self.request = request
- else:
- self.request = httpx.Request(
- method="POST", url="https://api-inference.huggingface.co/models"
- )
- if response is not None:
- self.response = response
- else:
- self.response = httpx.Response(
- status_code=status_code, request=self.request
- )
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
-hf_task_list = [
- "text-generation-inference",
- "conversational",
- "text-classification",
- "text-generation",
-]
-
-hf_tasks = Literal[
- "text-generation-inference",
- "conversational",
- "text-classification",
- "text-generation",
-]
-
-hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
- "sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
-]
-
-
-class HuggingfaceConfig:
- """
- Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
- """
-
- hf_task: Optional[hf_tasks] = (
- None # litellm-specific param, used to know the api spec to use when calling huggingface api
- )
- best_of: Optional[int] = None
- decoder_input_details: Optional[bool] = None
- details: Optional[bool] = True # enables returning logprobs + best of
- max_new_tokens: Optional[int] = None
- repetition_penalty: Optional[float] = None
- return_full_text: Optional[bool] = (
- False # by default don't return the input as part of the output
- )
- seed: Optional[int] = None
- temperature: Optional[float] = None
- top_k: Optional[int] = None
- top_n_tokens: Optional[int] = None
- top_p: Optional[int] = None
- truncate: Optional[int] = None
- typical_p: Optional[float] = None
- watermark: Optional[bool] = None
-
- def __init__(
- self,
- best_of: Optional[int] = None,
- decoder_input_details: Optional[bool] = None,
- details: Optional[bool] = None,
- max_new_tokens: Optional[int] = None,
- repetition_penalty: Optional[float] = None,
- return_full_text: Optional[bool] = None,
- seed: Optional[int] = None,
- temperature: Optional[float] = None,
- top_k: Optional[int] = None,
- top_n_tokens: Optional[int] = None,
- top_p: Optional[int] = None,
- truncate: Optional[int] = None,
- typical_p: Optional[float] = None,
- watermark: Optional[bool] = None,
- ) -> None:
- locals_ = locals()
- for key, value in locals_.items():
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
-
- def get_special_options_params(self):
- return ["use_cache", "wait_for_model"]
-
- def get_supported_openai_params(self):
- return [
- "stream",
- "temperature",
- "max_tokens",
- "max_completion_tokens",
- "top_p",
- "stop",
- "n",
- "echo",
- ]
-
- def map_openai_params(
- self, non_default_params: dict, optional_params: dict
- ) -> dict:
- for param, value in non_default_params.items():
- # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
- if param == "temperature":
- if value == 0.0 or value == 0:
- # hugging face exception raised when temp==0
- # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
- value = 0.01
- optional_params["temperature"] = value
- if param == "top_p":
- optional_params["top_p"] = value
- if param == "n":
- optional_params["best_of"] = value
- optional_params["do_sample"] = (
- True # Need to sample if you want best of for hf inference endpoints
- )
- if param == "stream":
- optional_params["stream"] = value
- if param == "stop":
- optional_params["stop"] = value
- if param == "max_tokens" or param == "max_completion_tokens":
- # HF TGI raises the following exception when max_new_tokens==0
- # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
- if value == 0:
- value = 1
- optional_params["max_new_tokens"] = value
- if param == "echo":
- # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
- # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
- optional_params["decoder_input_details"] = True
- return optional_params
-
- def get_hf_api_key(self) -> Optional[str]:
- return get_secret_str("HUGGINGFACE_API_KEY")
-
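# A minimal usage sketch of HuggingfaceConfig.map_openai_params above (not a definitive
# reference): temperature=0 is bumped to 0.01, max_tokens maps to max_new_tokens, and
# n maps to best_of with sampling enabled.
_hf_mapped = HuggingfaceConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 256, "n": 2},
    optional_params={},
)
assert _hf_mapped == {
    "temperature": 0.01,
    "max_new_tokens": 256,
    "best_of": 2,
    "do_sample": True,
}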
-
-def output_parser(generated_text: str):
- """
- Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
-
- Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
- """
-    chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
- for token in chat_template_tokens:
- if generated_text.strip().startswith(token):
- generated_text = generated_text.replace(token, "", 1)
- if generated_text.endswith(token):
- generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
- return generated_text
-
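# A minimal sketch of output_parser: leading and trailing chat-template tokens are
# stripped, anything else is returned unchanged.
assert output_parser("<|assistant|>Hello from LiteLLM") == "Hello from LiteLLM"
assert output_parser("Hello from LiteLLM</s>") == "Hello from LiteLLM"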
-
-tgi_models_cache = None
-conv_models_cache = None
-
-
-def read_tgi_conv_models():
- try:
- global tgi_models_cache, conv_models_cache
- # Check if the cache is already populated
- # so we don't keep on reading txt file if there are 1k requests
- if (tgi_models_cache is not None) and (conv_models_cache is not None):
- return tgi_models_cache, conv_models_cache
- # If not, read the file and populate the cache
- tgi_models = set()
- script_directory = os.path.dirname(os.path.abspath(__file__))
- # Construct the file path relative to the script's directory
- file_path = os.path.join(
- script_directory,
- "huggingface_llms_metadata",
- "hf_text_generation_models.txt",
- )
-
- with open(file_path, "r") as file:
- for line in file:
- tgi_models.add(line.strip())
-
- # Cache the set for future use
- tgi_models_cache = tgi_models
-
- # If not, read the file and populate the cache
- file_path = os.path.join(
- script_directory,
- "huggingface_llms_metadata",
- "hf_conversational_models.txt",
- )
- conv_models = set()
- with open(file_path, "r") as file:
- for line in file:
- conv_models.add(line.strip())
- # Cache the set for future use
- conv_models_cache = conv_models
- return tgi_models, conv_models
- except Exception:
- return set(), set()
-
-
-def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
- # read text file, cast it to set
- # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
- if model.split("/")[0] in hf_task_list:
- split_model = model.split("/", 1)
- return split_model[0], split_model[1] # type: ignore
- tgi_models, conversational_models = read_tgi_conv_models()
- if model in tgi_models:
- return "text-generation-inference", model
- elif model in conversational_models:
- return "conversational", model
- elif "roneneldan/TinyStories" in model:
- return "text-generation", model
- else:
- return "text-generation-inference", model # default to tgi
-
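# A minimal sketch of get_hf_task_for_model: an explicit "<task>/<model>" prefix wins,
# otherwise the bundled metadata files are consulted and TGI is the fallback. The second
# model name is hypothetical and assumed absent from the metadata files.
assert get_hf_task_for_model("conversational/facebook/blenderbot-400M-distill") == (
    "conversational",
    "facebook/blenderbot-400M-distill",
)
assert get_hf_task_for_model("my-org/my-private-model") == (
    "text-generation-inference",
    "my-org/my-private-model",
)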
-
-from litellm.llms.custom_httpx.http_handler import (
- AsyncHTTPHandler,
- HTTPHandler,
- get_async_httpx_client,
-)
-
-
-def get_hf_task_embedding_for_model(
- model: str, task_type: Optional[str], api_base: str
-) -> Optional[str]:
- if task_type is not None:
- if task_type in get_args(hf_tasks_embeddings):
- return task_type
- else:
- raise Exception(
- "Invalid task_type={}. Expected one of={}".format(
- task_type, hf_tasks_embeddings
- )
- )
- http_client = HTTPHandler(concurrent_limit=1)
-
- model_info = http_client.get(url=api_base)
-
- model_info_dict = model_info.json()
-
- pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
-
- return pipeline_tag
-
-
-async def async_get_hf_task_embedding_for_model(
- model: str, task_type: Optional[str], api_base: str
-) -> Optional[str]:
- if task_type is not None:
- if task_type in get_args(hf_tasks_embeddings):
- return task_type
- else:
- raise Exception(
- "Invalid task_type={}. Expected one of={}".format(
- task_type, hf_tasks_embeddings
- )
- )
- http_client = get_async_httpx_client(
- llm_provider=litellm.LlmProviders.HUGGINGFACE,
- )
-
- model_info = await http_client.get(url=api_base)
-
- model_info_dict = model_info.json()
-
- pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
-
- return pipeline_tag
-
-
-class Huggingface(BaseLLM):
- _client_session: Optional[httpx.Client] = None
- _aclient_session: Optional[httpx.AsyncClient] = None
-
- def __init__(self) -> None:
- super().__init__()
-
- def _validate_environment(self, api_key, headers) -> dict:
- default_headers = {
- "content-type": "application/json",
- }
- if api_key and headers is None:
- default_headers["Authorization"] = (
- f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
- )
- headers = default_headers
- elif headers:
- headers = headers
- else:
- headers = default_headers
- return headers
-
- def convert_to_model_response_object( # noqa: PLR0915
- self,
- completion_response,
- model_response: litellm.ModelResponse,
- task: hf_tasks,
- optional_params,
- encoding,
- input_text,
- model,
- ):
- if task == "conversational":
- if len(completion_response["generated_text"]) > 0: # type: ignore
- model_response.choices[0].message.content = completion_response[ # type: ignore
- "generated_text"
- ]
- elif task == "text-generation-inference":
- if (
- not isinstance(completion_response, list)
- or not isinstance(completion_response[0], dict)
- or "generated_text" not in completion_response[0]
- ):
- raise HuggingfaceError(
- status_code=422,
- message=f"response is not in expected format - {completion_response}",
- )
-
- if len(completion_response[0]["generated_text"]) > 0:
- model_response.choices[0].message.content = output_parser( # type: ignore
- completion_response[0]["generated_text"]
- )
- ## GETTING LOGPROBS + FINISH REASON
- if (
- "details" in completion_response[0]
- and "tokens" in completion_response[0]["details"]
- ):
- model_response.choices[0].finish_reason = completion_response[0][
- "details"
- ]["finish_reason"]
- sum_logprob = 0
- for token in completion_response[0]["details"]["tokens"]:
- if token["logprob"] is not None:
- sum_logprob += token["logprob"]
- setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
- if "best_of" in optional_params and optional_params["best_of"] > 1:
- if (
- "details" in completion_response[0]
- and "best_of_sequences" in completion_response[0]["details"]
- ):
- choices_list = []
- for idx, item in enumerate(
- completion_response[0]["details"]["best_of_sequences"]
- ):
- sum_logprob = 0
- for token in item["tokens"]:
- if token["logprob"] is not None:
- sum_logprob += token["logprob"]
- if len(item["generated_text"]) > 0:
- message_obj = Message(
- content=output_parser(item["generated_text"]),
- logprobs=sum_logprob,
- )
- else:
- message_obj = Message(content=None)
- choice_obj = Choices(
- finish_reason=item["finish_reason"],
- index=idx + 1,
- message=message_obj,
- )
- choices_list.append(choice_obj)
- model_response.choices.extend(choices_list)
- elif task == "text-classification":
- model_response.choices[0].message.content = json.dumps( # type: ignore
- completion_response
- )
- else:
- if len(completion_response[0]["generated_text"]) > 0:
- model_response.choices[0].message.content = output_parser( # type: ignore
- completion_response[0]["generated_text"]
- )
- ## CALCULATING USAGE
- prompt_tokens = 0
- try:
- prompt_tokens = len(
- encoding.encode(input_text)
- ) ##[TODO] use the llama2 tokenizer here
- except Exception:
-            # this should remain non-blocking; we should not block a response from returning if calculating usage fails
- pass
- output_text = model_response["choices"][0]["message"].get("content", "")
- if output_text is not None and len(output_text) > 0:
- completion_tokens = 0
- try:
- completion_tokens = len(
- encoding.encode(
- model_response["choices"][0]["message"].get("content", "")
- )
- ) ##[TODO] use the llama2 tokenizer here
- except Exception:
-                # this should remain non-blocking; we should not block a response from returning if calculating usage fails
- pass
- else:
- completion_tokens = 0
-
- model_response.created = int(time.time())
- model_response.model = model
- usage = Usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- )
- setattr(model_response, "usage", usage)
- model_response._hidden_params["original_response"] = completion_response
- return model_response
-
- def completion( # noqa: PLR0915
- self,
- model: str,
- messages: list,
- api_base: Optional[str],
- headers: Optional[dict],
- model_response: ModelResponse,
- print_verbose: Callable,
- timeout: float,
- encoding,
- api_key,
- logging_obj,
- optional_params: dict,
- custom_prompt_dict={},
- acompletion: bool = False,
- litellm_params=None,
- logger_fn=None,
- ):
- super().completion()
- exception_mapping_worked = False
- try:
- headers = self._validate_environment(api_key, headers)
- task, model = get_hf_task_for_model(model)
- ## VALIDATE API FORMAT
- if task is None or not isinstance(task, str) or task not in hf_task_list:
- raise Exception(
- "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
- )
-
- print_verbose(f"{model}, {task}")
- completion_url = ""
- input_text = ""
- if "https" in model:
- completion_url = model
- elif api_base:
- completion_url = api_base
- elif "HF_API_BASE" in os.environ:
- completion_url = os.getenv("HF_API_BASE", "")
- elif "HUGGINGFACE_API_BASE" in os.environ:
- completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
- else:
- completion_url = f"https://api-inference.huggingface.co/models/{model}"
-
- ## Load Config
- config = litellm.HuggingfaceConfig.get_config()
- for k, v in config.items():
- if (
- k not in optional_params
- ): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
- optional_params[k] = v
-
- ### MAP INPUT PARAMS
- #### HANDLE SPECIAL PARAMS
- special_params = HuggingfaceConfig().get_special_options_params()
- special_params_dict = {}
- # Create a list of keys to pop after iteration
- keys_to_pop = []
-
- for k, v in optional_params.items():
- if k in special_params:
- special_params_dict[k] = v
- keys_to_pop.append(k)
-
- # Pop the keys from the dictionary after iteration
- for k in keys_to_pop:
- optional_params.pop(k)
- if task == "conversational":
- inference_params = copy.deepcopy(optional_params)
- inference_params.pop("details")
- inference_params.pop("return_full_text")
- past_user_inputs = []
- generated_responses = []
- text = ""
- for message in messages:
- if message["role"] == "user":
- if text != "":
- past_user_inputs.append(text)
- text = message["content"]
- elif message["role"] == "assistant" or message["role"] == "system":
- generated_responses.append(message["content"])
- data = {
- "inputs": {
- "text": text,
- "past_user_inputs": past_user_inputs,
- "generated_responses": generated_responses,
- },
- "parameters": inference_params,
- }
- input_text = "".join(message["content"] for message in messages)
- elif task == "text-generation-inference":
- # always send "details" and "return_full_text" as params
- if model in custom_prompt_dict:
- # check if the model has a registered custom prompt
- model_prompt_details = custom_prompt_dict[model]
- prompt = custom_prompt(
- role_dict=model_prompt_details.get("roles", None),
- initial_prompt_value=model_prompt_details.get(
- "initial_prompt_value", ""
- ),
- final_prompt_value=model_prompt_details.get(
- "final_prompt_value", ""
- ),
- messages=messages,
- )
- else:
- prompt = prompt_factory(model=model, messages=messages)
- data = {
- "inputs": prompt, # type: ignore
- "parameters": optional_params,
- "stream": ( # type: ignore
- True
- if "stream" in optional_params
- and isinstance(optional_params["stream"], bool)
- and optional_params["stream"] is True # type: ignore
- else False
- ),
- }
- input_text = prompt
- else:
- # Non TGI and Conversational llms
- # We need this branch, it removes 'details' and 'return_full_text' from params
- if model in custom_prompt_dict:
- # check if the model has a registered custom prompt
- model_prompt_details = custom_prompt_dict[model]
- prompt = custom_prompt(
- role_dict=model_prompt_details.get("roles", {}),
- initial_prompt_value=model_prompt_details.get(
- "initial_prompt_value", ""
- ),
- final_prompt_value=model_prompt_details.get(
- "final_prompt_value", ""
- ),
- bos_token=model_prompt_details.get("bos_token", ""),
- eos_token=model_prompt_details.get("eos_token", ""),
- messages=messages,
- )
- else:
- prompt = prompt_factory(model=model, messages=messages)
- inference_params = copy.deepcopy(optional_params)
- inference_params.pop("details")
- inference_params.pop("return_full_text")
- data = {
- "inputs": prompt, # type: ignore
- }
- if task == "text-generation-inference":
- data["parameters"] = inference_params
- data["stream"] = ( # type: ignore
- True # type: ignore
- if "stream" in optional_params
- and optional_params["stream"] is True
- else False
- )
- input_text = prompt
-
- ### RE-ADD SPECIAL PARAMS
- if len(special_params_dict.keys()) > 0:
- data.update({"options": special_params_dict})
-
- ## LOGGING
- logging_obj.pre_call(
- input=input_text,
- api_key=api_key,
- additional_args={
- "complete_input_dict": data,
- "task": task,
- "headers": headers,
- "api_base": completion_url,
- "acompletion": acompletion,
- },
- )
- ## COMPLETION CALL
-
- # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
- ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
- if ssl_verify in ["True", "False"]:
- ssl_verify = bool(ssl_verify)
-
- if acompletion is True:
- ### ASYNC STREAMING
- if optional_params.get("stream", False):
- return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore
- else:
- ### ASYNC COMPLETION
- return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
- ### SYNC STREAMING
- if "stream" in optional_params and optional_params["stream"] is True:
- response = requests.post(
- completion_url,
- headers=headers,
- data=json.dumps(data),
- stream=optional_params["stream"],
- verify=ssl_verify,
- )
- return response.iter_lines()
- ### SYNC COMPLETION
- else:
- response = requests.post(
- completion_url,
- headers=headers,
- data=json.dumps(data),
- verify=ssl_verify,
- )
-
- ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
- is_streamed = False
- if (
- response.__dict__["headers"].get("Content-Type", "")
- == "text/event-stream"
- ):
- is_streamed = True
-
- # iterate over the complete streamed response, and return the final answer
- if is_streamed:
- streamed_response = CustomStreamWrapper(
- completion_stream=response.iter_lines(),
- model=model,
- custom_llm_provider="huggingface",
- logging_obj=logging_obj,
- )
- content = ""
- for chunk in streamed_response:
- content += chunk["choices"][0]["delta"]["content"]
- completion_response: List[Dict[str, Any]] = [
- {"generated_text": content}
- ]
- ## LOGGING
- logging_obj.post_call(
- input=input_text,
- api_key=api_key,
- original_response=completion_response,
- additional_args={"complete_input_dict": data, "task": task},
- )
- else:
- ## LOGGING
- logging_obj.post_call(
- input=input_text,
- api_key=api_key,
- original_response=response.text,
- additional_args={"complete_input_dict": data, "task": task},
- )
- ## RESPONSE OBJECT
- try:
- completion_response = response.json()
- if isinstance(completion_response, dict):
- completion_response = [completion_response]
- except Exception:
- import traceback
-
- raise HuggingfaceError(
- message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}",
- status_code=response.status_code,
- )
- print_verbose(f"response: {completion_response}")
- if (
- isinstance(completion_response, dict)
- and "error" in completion_response
- ):
- print_verbose(f"completion error: {completion_response['error']}") # type: ignore
- print_verbose(f"response.status_code: {response.status_code}")
- raise HuggingfaceError(
- message=completion_response["error"], # type: ignore
- status_code=response.status_code,
- )
- return self.convert_to_model_response_object(
- completion_response=completion_response,
- model_response=model_response,
- task=task,
- optional_params=optional_params,
- encoding=encoding,
- input_text=input_text,
- model=model,
- )
- except HuggingfaceError as e:
- exception_mapping_worked = True
- raise e
- except Exception as e:
- if exception_mapping_worked:
- raise e
- else:
- import traceback
-
- raise HuggingfaceError(status_code=500, message=traceback.format_exc())
-
- async def acompletion(
- self,
- api_base: str,
- data: dict,
- headers: dict,
- model_response: ModelResponse,
- task: hf_tasks,
- encoding: Any,
- input_text: str,
- model: str,
- optional_params: dict,
- timeout: float,
- ):
- # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
- ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
-
- response = None
- try:
- async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
- response = await client.post(url=api_base, json=data, headers=headers)
- response_json = response.json()
- if response.status_code != 200:
- if "error" in response_json:
- raise HuggingfaceError(
- status_code=response.status_code,
- message=response_json["error"],
- request=response.request,
- response=response,
- )
- else:
- raise HuggingfaceError(
- status_code=response.status_code,
- message=response.text,
- request=response.request,
- response=response,
- )
-
- ## RESPONSE OBJECT
- return self.convert_to_model_response_object(
- completion_response=response_json,
- model_response=model_response,
- task=task,
- encoding=encoding,
- input_text=input_text,
- model=model,
- optional_params=optional_params,
- )
- except Exception as e:
- if isinstance(e, httpx.TimeoutException):
- raise HuggingfaceError(status_code=500, message="Request Timeout Error")
- elif isinstance(e, HuggingfaceError):
- raise e
- elif response is not None and hasattr(response, "text"):
- raise HuggingfaceError(
- status_code=500,
- message=f"{str(e)}\n\nOriginal Response: {response.text}",
- )
- else:
- raise HuggingfaceError(status_code=500, message=f"{str(e)}")
-
- async def async_streaming(
- self,
- logging_obj,
- api_base: str,
- data: dict,
- headers: dict,
- model_response: ModelResponse,
- model: str,
- timeout: float,
- ):
- # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
- ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
-
- async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
- response = client.stream(
- "POST", url=f"{api_base}", json=data, headers=headers
- )
- async with response as r:
- if r.status_code != 200:
- text = await r.aread()
- raise HuggingfaceError(
- status_code=r.status_code,
- message=str(text),
- )
- """
- Check first chunk for error message.
- If error message, raise error.
- If not - add back to stream
- """
- # Async iterator over the lines in the response body
- response_iterator = r.aiter_lines()
-
- # Attempt to get the first line/chunk from the response
- try:
- first_chunk = await response_iterator.__anext__()
- except StopAsyncIteration:
- # Handle the case where there are no lines to read (empty response)
- first_chunk = ""
-
- # Check the first chunk for an error message
- if (
- "error" in first_chunk.lower()
- ): # Adjust this condition based on how error messages are structured
- raise HuggingfaceError(
- status_code=400,
- message=first_chunk,
- )
-
- # Create a new async generator that begins with the first_chunk and includes the remaining items
- async def custom_stream_with_first_chunk():
- yield first_chunk # Yield back the first chunk
- async for (
- chunk
- ) in response_iterator: # Continue yielding the rest of the chunks
- yield chunk
-
- # Creating a new completion stream that starts with the first chunk
- completion_stream = custom_stream_with_first_chunk()
-
- streamwrapper = CustomStreamWrapper(
- completion_stream=completion_stream,
- model=model,
- custom_llm_provider="huggingface",
- logging_obj=logging_obj,
- )
-
- async for transformed_chunk in streamwrapper:
- yield transformed_chunk
-
- def _transform_input_on_pipeline_tag(
- self, input: List, pipeline_tag: Optional[str]
- ) -> dict:
- if pipeline_tag is None:
- return {"inputs": input}
- if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
- if len(input) < 2:
- raise HuggingfaceError(
- status_code=400,
- message="sentence-similarity requires 2+ sentences",
- )
- return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
- elif pipeline_tag == "rerank":
- if len(input) < 2:
- raise HuggingfaceError(
- status_code=400,
- message="reranker requires 2+ sentences",
- )
- return {"inputs": {"query": input[0], "texts": input[1:]}}
- return {"inputs": input} # default to feature-extraction pipeline tag
-
- async def _async_transform_input(
- self,
- model: str,
- task_type: Optional[str],
- embed_url: str,
- input: List,
- optional_params: dict,
- ) -> dict:
- hf_task = await async_get_hf_task_embedding_for_model(
- model=model, task_type=task_type, api_base=embed_url
- )
-
- data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
-
- if len(optional_params.keys()) > 0:
- data["options"] = optional_params
-
- return data
-
- def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
- special_options_keys = HuggingfaceConfig().get_special_options_params()
- special_parameters_keys = [
- "min_length",
- "max_length",
- "top_k",
- "top_p",
- "temperature",
- "repetition_penalty",
- "max_time",
- ]
-
- for k, v in optional_params.items():
- if k in special_options_keys:
- data.setdefault("options", {})
- data["options"][k] = v
- elif k in special_parameters_keys:
- data.setdefault("parameters", {})
- data["parameters"][k] = v
- else:
- data[k] = v
-
- return data
-
- def _transform_input(
- self,
- input: List,
- model: str,
- call_type: Literal["sync", "async"],
- optional_params: dict,
- embed_url: str,
- ) -> dict:
- data: Dict = {}
- ## TRANSFORMATION ##
- if "sentence-transformers" in model:
- if len(input) == 0:
- raise HuggingfaceError(
- status_code=400,
- message="sentence transformers requires 2+ sentences",
- )
- data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
- else:
- data = {"inputs": input}
-
- task_type = optional_params.pop("input_type", None)
-
- if call_type == "sync":
- hf_task = get_hf_task_embedding_for_model(
- model=model, task_type=task_type, api_base=embed_url
- )
- elif call_type == "async":
- return self._async_transform_input(
- model=model, task_type=task_type, embed_url=embed_url, input=input
- ) # type: ignore
-
- data = self._transform_input_on_pipeline_tag(
- input=input, pipeline_tag=hf_task
- )
-
- if len(optional_params.keys()) > 0:
- data = self._process_optional_params(
- data=data, optional_params=optional_params
- )
-
- return data
-
- def _process_embedding_response(
- self,
- embeddings: dict,
- model_response: litellm.EmbeddingResponse,
- model: str,
- input: List,
- encoding: Any,
- ) -> litellm.EmbeddingResponse:
- output_data = []
- if "similarities" in embeddings:
- for idx, embedding in embeddings["similarities"]:
- output_data.append(
- {
- "object": "embedding",
- "index": idx,
- "embedding": embedding, # flatten list returned from hf
- }
- )
- else:
- for idx, embedding in enumerate(embeddings):
- if isinstance(embedding, float):
- output_data.append(
- {
- "object": "embedding",
- "index": idx,
- "embedding": embedding, # flatten list returned from hf
- }
- )
- elif isinstance(embedding, list) and isinstance(embedding[0], float):
- output_data.append(
- {
- "object": "embedding",
- "index": idx,
- "embedding": embedding, # flatten list returned from hf
- }
- )
- else:
- output_data.append(
- {
- "object": "embedding",
- "index": idx,
- "embedding": embedding[0][
- 0
- ], # flatten list returned from hf
- }
- )
- model_response.object = "list"
- model_response.data = output_data
- model_response.model = model
- input_tokens = 0
- for text in input:
- input_tokens += len(encoding.encode(text))
-
- setattr(
- model_response,
- "usage",
- litellm.Usage(
- prompt_tokens=input_tokens,
- completion_tokens=input_tokens,
- total_tokens=input_tokens,
- prompt_tokens_details=None,
- completion_tokens_details=None,
- ),
- )
- return model_response
-
- async def aembedding(
- self,
- model: str,
- input: list,
- model_response: litellm.utils.EmbeddingResponse,
- timeout: Union[float, httpx.Timeout],
- logging_obj: LiteLLMLoggingObj,
- optional_params: dict,
- api_base: str,
- api_key: Optional[str],
- headers: dict,
- encoding: Callable,
- client: Optional[AsyncHTTPHandler] = None,
- ):
- ## TRANSFORMATION ##
- data = self._transform_input(
- input=input,
- model=model,
- call_type="sync",
- optional_params=optional_params,
- embed_url=api_base,
- )
-
- ## LOGGING
- logging_obj.pre_call(
- input=input,
- api_key=api_key,
- additional_args={
- "complete_input_dict": data,
- "headers": headers,
- "api_base": api_base,
- },
- )
- ## COMPLETION CALL
- if client is None:
- client = get_async_httpx_client(
- llm_provider=litellm.LlmProviders.HUGGINGFACE,
- )
-
- response = await client.post(api_base, headers=headers, data=json.dumps(data))
-
- ## LOGGING
- logging_obj.post_call(
- input=input,
- api_key=api_key,
- additional_args={"complete_input_dict": data},
- original_response=response,
- )
-
- embeddings = response.json()
-
- if "error" in embeddings:
- raise HuggingfaceError(status_code=500, message=embeddings["error"])
-
- ## PROCESS RESPONSE ##
- return self._process_embedding_response(
- embeddings=embeddings,
- model_response=model_response,
- model=model,
- input=input,
- encoding=encoding,
- )
-
- def embedding(
- self,
- model: str,
- input: list,
- model_response: litellm.EmbeddingResponse,
- optional_params: dict,
- logging_obj: LiteLLMLoggingObj,
- encoding: Callable,
- api_key: Optional[str] = None,
- api_base: Optional[str] = None,
- timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
- aembedding: Optional[bool] = None,
- client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
- ) -> litellm.EmbeddingResponse:
- super().embedding()
- headers = self._validate_environment(api_key, headers=None)
- # print_verbose(f"{model}, {task}")
- embed_url = ""
- if "https" in model:
- embed_url = model
- elif api_base:
- embed_url = api_base
- elif "HF_API_BASE" in os.environ:
- embed_url = os.getenv("HF_API_BASE", "")
- elif "HUGGINGFACE_API_BASE" in os.environ:
- embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
- else:
- embed_url = f"https://api-inference.huggingface.co/models/{model}"
-
- ## ROUTING ##
- if aembedding is True:
- return self.aembedding(
- input=input,
- model_response=model_response,
- timeout=timeout,
- logging_obj=logging_obj,
- headers=headers,
- api_base=embed_url, # type: ignore
- api_key=api_key,
- client=client if isinstance(client, AsyncHTTPHandler) else None,
- model=model,
- optional_params=optional_params,
- encoding=encoding,
- )
-
- ## TRANSFORMATION ##
-
- data = self._transform_input(
- input=input,
- model=model,
- call_type="sync",
- optional_params=optional_params,
- embed_url=embed_url,
- )
-
- ## LOGGING
- logging_obj.pre_call(
- input=input,
- api_key=api_key,
- additional_args={
- "complete_input_dict": data,
- "headers": headers,
- "api_base": embed_url,
- },
- )
- ## COMPLETION CALL
- if client is None or not isinstance(client, HTTPHandler):
- client = HTTPHandler(concurrent_limit=1)
- response = client.post(embed_url, headers=headers, data=json.dumps(data))
-
- ## LOGGING
- logging_obj.post_call(
- input=input,
- api_key=api_key,
- additional_args={"complete_input_dict": data},
- original_response=response,
- )
-
- embeddings = response.json()
-
- if "error" in embeddings:
- raise HuggingfaceError(status_code=500, message=embeddings["error"])
-
- ## PROCESS RESPONSE ##
- return self._process_embedding_response(
- embeddings=embeddings,
- model_response=model_response,
- model=model,
- input=input,
- encoding=encoding,
- )
-
- def _transform_logprobs(
- self, hf_response: Optional[List]
- ) -> Optional[TextCompletionLogprobs]:
- """
- Transform Hugging Face logprobs to OpenAI.Completion() format
- """
- if hf_response is None:
- return None
-
- # Initialize an empty list for the transformed logprobs
- _logprob: TextCompletionLogprobs = TextCompletionLogprobs(
- text_offset=[],
- token_logprobs=[],
- tokens=[],
- top_logprobs=[],
- )
-
- # For each Hugging Face response, transform the logprobs
- for response in hf_response:
- # Extract the relevant information from the response
- response_details = response["details"]
- top_tokens = response_details.get("top_tokens", {})
-
- for i, token in enumerate(response_details["prefill"]):
- # Extract the text of the token
- token_text = token["text"]
-
- # Extract the logprob of the token
- token_logprob = token["logprob"]
-
- # Add the token information to the 'token_info' list
- _logprob.tokens.append(token_text)
- _logprob.token_logprobs.append(token_logprob)
-
- # stub this to work with llm eval harness
- top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
- _logprob.top_logprobs.append(top_alt_tokens)
-
- # For each element in the 'tokens' list, extract the relevant information
- for i, token in enumerate(response_details["tokens"]):
- # Extract the text of the token
- token_text = token["text"]
-
- # Extract the logprob of the token
- token_logprob = token["logprob"]
-
- top_alt_tokens = {}
- temp_top_logprobs = []
- if top_tokens != {}:
- temp_top_logprobs = top_tokens[i]
-
- # top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
- for elem in temp_top_logprobs:
- text = elem["text"]
- logprob = elem["logprob"]
- top_alt_tokens[text] = logprob
-
- # Add the token information to the 'token_info' list
- _logprob.tokens.append(token_text)
- _logprob.token_logprobs.append(token_logprob)
- _logprob.top_logprobs.append(top_alt_tokens)
-
- # Add the text offset of the token
- # This is computed as the sum of the lengths of all previous tokens
- _logprob.text_offset.append(
- sum(len(t["text"]) for t in response_details["tokens"][:i])
- )
-
- return _logprob
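# A minimal sketch of the embedding request bodies built by the class above. The
# pipeline tag decides the payload shape, and optional params are split between
# Hugging Face "options" and "parameters".
_hf = Huggingface()
assert _hf._transform_input_on_pipeline_tag(
    input=["what is litellm?", "doc one", "doc two"], pipeline_tag="rerank"
) == {"inputs": {"query": "what is litellm?", "texts": ["doc one", "doc two"]}}
assert _hf._transform_input_on_pipeline_tag(
    input=["hello world"], pipeline_tag="feature-extraction"
) == {"inputs": ["hello world"]}
assert _hf._process_optional_params(
    data={"inputs": ["hello world"]},
    optional_params={"wait_for_model": True, "temperature": 0.2},
) == {
    "inputs": ["hello world"],
    "options": {"wait_for_model": True},
    "parameters": {"temperature": 0.2},
}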
diff --git a/litellm/llms/maritalk.py b/litellm/llms/maritalk.py
index 813dfa8eae..10df36394b 100644
--- a/litellm/llms/maritalk.py
+++ b/litellm/llms/maritalk.py
@@ -4,59 +4,42 @@ import time
import traceback
import types
from enum import Enum
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union
-import requests # type: ignore
+from httpx._models import Headers
import litellm
+from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.utils import Choices, Message, ModelResponse, Usage
-class MaritalkError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
+class MaritalkError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[dict, Headers]] = None,
+ ):
+ super().__init__(status_code=status_code, message=message, headers=headers)
-class MaritTalkConfig:
- """
- The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:
-
- - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
-
- - `model` (string): The model used for conversation. Default is 'maritalk'.
-
- - `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
-
- - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
-
- - `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
-
- - `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
-
-    - `stopping_tokens` (list of string): List of tokens at which the conversation can be stopped.
- """
-
- max_tokens: Optional[int] = None
- model: Optional[str] = None
- do_sample: Optional[bool] = None
- temperature: Optional[float] = None
- top_p: Optional[float] = None
- repetition_penalty: Optional[float] = None
- stopping_tokens: Optional[List[str]] = None
+class MaritalkConfig(OpenAIGPTConfig):
def __init__(
self,
- max_tokens: Optional[int] = None,
- model: Optional[str] = None,
- do_sample: Optional[bool] = None,
- temperature: Optional[float] = None,
+ frequency_penalty: Optional[float] = None,
+ presence_penalty: Optional[float] = None,
top_p: Optional[float] = None,
- repetition_penalty: Optional[float] = None,
- stopping_tokens: Optional[List[str]] = None,
+ top_k: Optional[int] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ stop: Optional[List[str]] = None,
+ stream: Optional[bool] = None,
+ stream_options: Optional[dict] = None,
+ tools: Optional[List[dict]] = None,
+ tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@@ -65,129 +48,27 @@ class MaritTalkConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
+ def get_supported_openai_params(self, model: str) -> List:
+ return [
+ "frequency_penalty",
+ "presence_penalty",
+ "top_p",
+ "top_k",
+ "temperature",
+ "max_tokens",
+ "n",
+ "stop",
+ "stream",
+ "stream_options",
+ "tools",
+ "tool_choice",
+ ]
-def validate_environment(api_key):
- headers = {
- "accept": "application/json",
- "content-type": "application/json",
- }
- if api_key:
- headers["Authorization"] = f"Key {api_key}"
- return headers
-
-
-def completion(
- model: str,
- messages: list,
- api_base: str,
- model_response: ModelResponse,
- print_verbose: Callable,
- encoding,
- api_key,
- logging_obj,
- optional_params: dict,
- litellm_params=None,
- logger_fn=None,
-):
- headers = validate_environment(api_key)
- completion_url = api_base
- model = model
-
- ## Load Config
- config = litellm.MaritTalkConfig.get_config()
- for k, v in config.items():
- if (
- k not in optional_params
- ): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
- optional_params[k] = v
-
- data = {
- "messages": messages,
- **optional_params,
- }
-
- ## LOGGING
- logging_obj.pre_call(
- input=messages,
- api_key=api_key,
- additional_args={"complete_input_dict": data},
- )
- ## COMPLETION CALL
- response = requests.post(
- completion_url,
- headers=headers,
- data=json.dumps(data),
- stream=optional_params["stream"] if "stream" in optional_params else False,
- )
- if "stream" in optional_params and optional_params["stream"] is True:
- return response.iter_lines()
- else:
- ## LOGGING
- logging_obj.post_call(
- input=messages,
- api_key=api_key,
- original_response=response.text,
- additional_args={"complete_input_dict": data},
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, Headers]
+ ) -> BaseLLMException:
+ return MaritalkError(
+ status_code=status_code, message=error_message, headers=headers
)
- print_verbose(f"raw model_response: {response.text}")
- ## RESPONSE OBJECT
- completion_response = response.json()
- if "error" in completion_response:
- raise MaritalkError(
- message=completion_response["error"],
- status_code=response.status_code,
- )
- else:
- try:
- if len(completion_response["answer"]) > 0:
- model_response.choices[0].message.content = completion_response[ # type: ignore
- "answer"
- ]
- except Exception:
- raise MaritalkError(
- message=response.text, status_code=response.status_code
- )
-
- ## CALCULATING USAGE
- prompt = "".join(m["content"] for m in messages)
- prompt_tokens = len(encoding.encode(prompt))
- completion_tokens = len(
- encoding.encode(model_response["choices"][0]["message"].get("content", ""))
- )
-
- model_response.created = int(time.time())
- model_response.model = model
- usage = Usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- )
- setattr(model_response, "usage", usage)
- return model_response
-
-
-def embedding(
- model: str,
- input: list,
- api_key: Optional[str],
- logging_obj: Any,
- model_response=None,
- encoding=None,
-):
- pass
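# A minimal sketch of the new MaritalkConfig: because it now inherits the OpenAI-compatible
# config, OpenAI-style params are advertised directly via get_supported_openai_params.
# The model name below is only a placeholder.
assert "tool_choice" in MaritalkConfig().get_supported_openai_params(model="sabia-3")
assert "stream" in MaritalkConfig().get_supported_openai_params(model="sabia-3")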
diff --git a/litellm/llms/mistral/mistral_chat_transformation.py b/litellm/llms/mistral/mistral_chat_transformation.py
index aeb1a90fdb..50f08771c1 100644
--- a/litellm/llms/mistral/mistral_chat_transformation.py
+++ b/litellm/llms/mistral/mistral_chat_transformation.py
@@ -9,11 +9,16 @@ Docs - https://docs.mistral.ai/api/
import types
from typing import List, Literal, Optional, Tuple, Union
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.prompt_templates.common_utils import (
+ handle_messages_with_content_list_to_str_conversion,
+ strip_none_values_from_message,
+)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
-class MistralConfig:
+class MistralConfig(OpenAIGPTConfig):
"""
Reference: https://docs.mistral.ai/api/
@@ -67,23 +72,9 @@ class MistralConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
- def get_supported_openai_params(self):
+ def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
@@ -104,7 +95,13 @@ class MistralConfig:
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
- def map_openai_params(self, non_default_params: dict, optional_params: dict):
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
@@ -150,8 +147,9 @@ class MistralConfig:
)
return api_base, dynamic_api_key
- @classmethod
- def _transform_messages(cls, messages: List[AllMessageValues]):
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[AllMessageValues]:
"""
- handles scenario where content is list and not string
- content list is just text, and no images
@@ -160,48 +158,36 @@ class MistralConfig:
Motivation: mistral api doesn't support content as a list
"""
- new_messages = []
+ ## 1. If 'image_url' in content, then return as is
for m in messages:
- special_keys = ["role", "content", "tool_calls", "function_call"]
- extra_args = {}
- if isinstance(m, dict):
- for k, v in m.items():
- if k not in special_keys:
- extra_args[k] = v
- texts = ""
- _content = m.get("content")
- if _content is not None and isinstance(_content, list):
- for c in _content:
- _text: Optional[str] = c.get("text")
- if c["type"] == "image_url":
+ _content_block = m.get("content")
+ if _content_block and isinstance(_content_block, list):
+ for c in _content_block:
+ if c.get("type") == "image_url":
return messages
- elif c["type"] == "text" and isinstance(_text, str):
- texts += _text
- elif _content is not None and isinstance(_content, str):
- texts = _content
- new_m = {"role": m["role"], "content": texts, **extra_args}
+ ## 2. If content is list, then convert to string
+ messages = handle_messages_with_content_list_to_str_conversion(messages)
- if m.get("tool_calls"):
- new_m["tool_calls"] = m.get("tool_calls")
+ ## 3. Handle name in message
+ new_messages: List[AllMessageValues] = []
+ for m in messages:
+ m = MistralConfig._handle_name_in_message(m)
+ m = strip_none_values_from_message(m) # prevents 'extra_forbidden' error
+ new_messages.append(m)
- new_m = cls._handle_name_in_message(new_m)
-
- new_messages.append(new_m)
return new_messages
@classmethod
- def _handle_name_in_message(cls, message: dict) -> dict:
+ def _handle_name_in_message(cls, message: AllMessageValues) -> AllMessageValues:
"""
Mistral API only supports `name` in tool messages
If role == tool, then we keep `name`
Otherwise, we drop `name`
"""
- if message.get("name") is not None:
- if message["role"] == "tool":
- message["name"] = message.get("name")
- else:
- message.pop("name", None)
+ _name = message.get("name") # type: ignore
+ if _name is not None and message["role"] != "tool":
+ message.pop("name", None) # type: ignore
return message
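# A minimal sketch of _handle_name_in_message above: `name` survives only on tool messages.
assert MistralConfig._handle_name_in_message(
    {"role": "assistant", "content": "ok", "name": "helper"}
) == {"role": "assistant", "content": "ok"}
assert (
    MistralConfig._handle_name_in_message(
        {"role": "tool", "content": "42", "name": "get_answer", "tool_call_id": "call_1"}
    )["name"]
    == "get_answer"
)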
diff --git a/litellm/llms/nlp_cloud/chat/handler.py b/litellm/llms/nlp_cloud/chat/handler.py
new file mode 100644
index 0000000000..e82086ebf3
--- /dev/null
+++ b/litellm/llms/nlp_cloud/chat/handler.py
@@ -0,0 +1,140 @@
+import json
+import os
+import time
+import types
+from enum import Enum
+from typing import Any, Callable, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.utils import ModelResponse, Usage
+
+from ..common_utils import NLPCloudError
+from .transformation import NLPCloudConfig
+
+nlp_config = NLPCloudConfig()
+
+
+def completion(
+ model: str,
+ messages: list,
+ api_base: str,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ optional_params: dict,
+ litellm_params: dict,
+ logger_fn=None,
+ default_max_tokens_to_sample=None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ headers={},
+):
+ headers = nlp_config.validate_environment(
+ api_key=api_key,
+ headers=headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+
+ ## Load Config
+ config = litellm.NLPCloudConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+        ): # completion(top_k=3) > nlp_cloud_config(top_k=3) <- allows for dynamic variables to be passed in
+ optional_params[k] = v
+
+ completion_url_fragment_1 = api_base
+ completion_url_fragment_2 = "/generation"
+ model = model
+
+ completion_url = completion_url_fragment_1 + model + completion_url_fragment_2
+ data = nlp_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=None,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "headers": headers,
+ "api_base": completion_url,
+ },
+ )
+ ## COMPLETION CALL
+ if client is None or not isinstance(client, HTTPHandler):
+ client = _get_httpx_client()
+
+ response = client.post(
+ completion_url,
+ headers=headers,
+ data=json.dumps(data),
+ stream=optional_params["stream"] if "stream" in optional_params else False,
+ )
+ if "stream" in optional_params and optional_params["stream"] is True:
+ return clean_and_iterate_chunks(response)
+ else:
+ return nlp_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+
+
+# def clean_and_iterate_chunks(response):
+# def process_chunk(chunk):
+# print(f"received chunk: {chunk}")
+# cleaned_chunk = chunk.decode("utf-8")
+# # Perform further processing based on your needs
+# return cleaned_chunk
+
+
+# for line in response.iter_lines():
+# if line:
+# yield process_chunk(line)
+def clean_and_iterate_chunks(response):
+ buffer = b""
+
+ for chunk in response.iter_content(chunk_size=1024):
+ if not chunk:
+ break
+
+ buffer += chunk
+ while b"\x00" in buffer:
+ buffer = buffer.replace(b"\x00", b"")
+ yield buffer.decode("utf-8")
+ buffer = b""
+
+ # No more data expected, yield any remaining data in the buffer
+ if buffer:
+ yield buffer.decode("utf-8")
+
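# A minimal sketch of clean_and_iterate_chunks above: NLP Cloud streaming chunks arrive
# padded with NUL bytes, which are stripped before yielding decoded text. The response
# object here is a stand-in exposing only the one method the helper relies on.
class _FakeStreamingResponse:
    def iter_content(self, chunk_size=1024):
        yield b"Hello\x00\x00"
        yield b" world"


assert list(clean_and_iterate_chunks(_FakeStreamingResponse())) == ["Hello", " world"]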
+
+def embedding():
+ # logic for parsing in - calling - parsing out model embedding calls
+ pass
diff --git a/litellm/llms/nlp_cloud.py b/litellm/llms/nlp_cloud/chat/transformation.py
similarity index 50%
rename from litellm/llms/nlp_cloud.py
rename to litellm/llms/nlp_cloud/chat/transformation.py
index a959ea49a3..ec5540ca62 100644
--- a/litellm/llms/nlp_cloud.py
+++ b/litellm/llms/nlp_cloud/chat/transformation.py
@@ -1,26 +1,25 @@
import json
-import os
import time
-import types
-from enum import Enum
-from typing import Callable, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union
-import requests # type: ignore
+import httpx
-import litellm
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse, Usage
+from ..common_utils import NLPCloudError
-class NLPCloudError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
-class NLPCloudConfig:
+class NLPCloudConfig(BaseConfig):
"""
Reference: https://docs.nlpcloud.com/#generation
@@ -84,106 +83,119 @@ class NLPCloudConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
+ return super().get_config()
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ ) -> dict:
+ headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ }
+ if api_key:
+ headers["Authorization"] = f"Token {api_key}"
+ return headers
+
+ def get_supported_openai_params(self, model: str) -> List:
+ return [
+ "max_tokens",
+ "stream",
+ "temperature",
+ "top_p",
+ "presence_penalty",
+ "frequency_penalty",
+ "n",
+ "stop",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ for param, value in non_default_params.items():
+ if param == "max_tokens":
+ optional_params["max_length"] = value
+ if param == "stream":
+ optional_params["stream"] = value
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "presence_penalty":
+ optional_params["presence_penalty"] = value
+ if param == "frequency_penalty":
+ optional_params["frequency_penalty"] = value
+ if param == "n":
+ optional_params["num_return_sequences"] = value
+ if param == "stop":
+ optional_params["stop_sequences"] = value
+ return optional_params
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return NLPCloudError(
+ status_code=status_code, message=error_message, headers=headers
+ )
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ text = " ".join(convert_content_list_to_str(message) for message in messages)
+
+ data = {
+ "text": text,
+ **optional_params,
}
+ return data
-def validate_environment(api_key):
- headers = {
- "accept": "application/json",
- "content-type": "application/json",
- }
- if api_key:
- headers["Authorization"] = f"Token {api_key}"
- return headers
-
-
-def completion(
- model: str,
- messages: list,
- api_base: str,
- model_response: ModelResponse,
- print_verbose: Callable,
- encoding,
- api_key,
- logging_obj,
- optional_params: dict,
- litellm_params=None,
- logger_fn=None,
- default_max_tokens_to_sample=None,
-):
- headers = validate_environment(api_key)
-
- ## Load Config
- config = litellm.NLPCloudConfig.get_config()
- for k, v in config.items():
- if (
- k not in optional_params
- ): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
- optional_params[k] = v
-
- completion_url_fragment_1 = api_base
- completion_url_fragment_2 = "/generation"
- model = model
- text = " ".join(message["content"] for message in messages)
-
- data = {
- "text": text,
- **optional_params,
- }
-
- completion_url = completion_url_fragment_1 + model + completion_url_fragment_2
-
- ## LOGGING
- logging_obj.pre_call(
- input=text,
- api_key=api_key,
- additional_args={
- "complete_input_dict": data,
- "headers": headers,
- "api_base": completion_url,
- },
- )
- ## COMPLETION CALL
- response = requests.post(
- completion_url,
- headers=headers,
- data=json.dumps(data),
- stream=optional_params["stream"] if "stream" in optional_params else False,
- )
- if "stream" in optional_params and optional_params["stream"] is True:
- return clean_and_iterate_chunks(response)
- else:
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
## LOGGING
logging_obj.post_call(
- input=text,
+ input=None,
api_key=api_key,
- original_response=response.text,
- additional_args={"complete_input_dict": data},
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
)
- print_verbose(f"raw model_response: {response.text}")
+
## RESPONSE OBJECT
try:
- completion_response = response.json()
+ completion_response = raw_response.json()
except Exception:
- raise NLPCloudError(message=response.text, status_code=response.status_code)
+ raise NLPCloudError(
+ message=raw_response.text, status_code=raw_response.status_code
+ )
if "error" in completion_response:
raise NLPCloudError(
message=completion_response["error"],
- status_code=response.status_code,
+ status_code=raw_response.status_code,
)
else:
try:
@@ -194,7 +206,7 @@ def completion(
except Exception:
raise NLPCloudError(
message=json.dumps(completion_response),
- status_code=response.status_code,
+ status_code=raw_response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
@@ -210,37 +222,3 @@ def completion(
)
setattr(model_response, "usage", usage)
return model_response
-
-
-# def clean_and_iterate_chunks(response):
-# def process_chunk(chunk):
-# print(f"received chunk: {chunk}")
-# cleaned_chunk = chunk.decode("utf-8")
-# # Perform further processing based on your needs
-# return cleaned_chunk
-
-
-# for line in response.iter_lines():
-# if line:
-# yield process_chunk(line)
-def clean_and_iterate_chunks(response):
- buffer = b""
-
- for chunk in response.iter_content(chunk_size=1024):
- if not chunk:
- break
-
- buffer += chunk
- while b"\x00" in buffer:
- buffer = buffer.replace(b"\x00", b"")
- yield buffer.decode("utf-8")
- buffer = b""
-
- # No more data expected, yield any remaining data in the buffer
- if buffer:
- yield buffer.decode("utf-8")
-
-
-def embedding():
- # logic for parsing in - calling - parsing out model embedding calls
- pass
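# A minimal sketch of the new NLPCloudConfig above: OpenAI-style params are renamed to
# NLP Cloud's generation params, and messages are flattened into a single space-joined
# prompt. The model name below is only a placeholder.
_nlp = NLPCloudConfig()
assert _nlp.map_openai_params(
    non_default_params={"max_tokens": 100, "n": 2, "stop": ["###"]},
    optional_params={},
    model="finetuned-llama-3-70b",
    drop_params=False,
) == {"max_length": 100, "num_return_sequences": 2, "stop_sequences": ["###"]}
assert _nlp.transform_request(
    model="finetuned-llama-3-70b",
    messages=[{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}],
    optional_params={"max_length": 50},
    litellm_params={},
    headers={},
) == {"text": "Hello Hi", "max_length": 50}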
diff --git a/litellm/llms/nlp_cloud/common_utils.py b/litellm/llms/nlp_cloud/common_utils.py
new file mode 100644
index 0000000000..5488a2fd7a
--- /dev/null
+++ b/litellm/llms/nlp_cloud/common_utils.py
@@ -0,0 +1,15 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class NLPCloudError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[dict, httpx.Headers]] = None,
+ ):
+ super().__init__(status_code=status_code, message=message, headers=headers)
diff --git a/litellm/llms/nvidia_nim/chat.py b/litellm/llms/nvidia_nim/chat.py
index 99c88345e1..3f50c02dd9 100644
--- a/litellm/llms/nvidia_nim/chat.py
+++ b/litellm/llms/nvidia_nim/chat.py
@@ -11,8 +11,10 @@ API calling is done using the OpenAI SDK with an api_base
import types
from typing import Optional, Union
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
-class NvidiaNimConfig:
+
+class NvidiaNimConfig(OpenAIGPTConfig):
"""
Reference: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
@@ -42,21 +44,7 @@ class NvidiaNimConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
"""
@@ -132,7 +120,11 @@ class NvidiaNimConfig:
]
def map_openai_params(
- self, model: str, non_default_params: dict, optional_params: dict
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():
diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py
index 4e08419a7a..cc5fddf9f7 100644
--- a/litellm/llms/ollama/completion/transformation.py
+++ b/litellm/llms/ollama/completion/transformation.py
@@ -242,6 +242,7 @@ class OllamaConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index ce0df139d0..47555a3a48 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -14,6 +14,7 @@ from pydantic import BaseModel
import litellm
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import StreamingChoices
@@ -30,7 +31,7 @@ class OllamaError(Exception):
) # Call the base class constructor with the parameters it needs
-class OllamaChatConfig:
+class OllamaChatConfig(OpenAIGPTConfig):
"""
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
@@ -81,15 +82,10 @@ class OllamaChatConfig:
num_thread: Optional[int] = None
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
- temperature: Optional[float] = None
seed: Optional[int] = None
- stop: Optional[list] = (
- None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
- )
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
top_k: Optional[int] = None
- top_p: Optional[float] = None
system: Optional[str] = None
template: Optional[str] = None
@@ -120,26 +116,9 @@ class OllamaChatConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and k != "function_name" # special param for function calling
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
- def get_supported_openai_params(
- self,
- ):
+ def get_supported_openai_params(self, model: str):
return [
"max_tokens",
"max_completion_tokens",
@@ -156,8 +135,12 @@ class OllamaChatConfig:
]
def map_openai_params(
- self, model: str, non_default_params: dict, optional_params: dict
- ):
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["num_predict"] = value
diff --git a/litellm/llms/oobabooga.py b/litellm/llms/oobabooga/chat/oobabooga.py
similarity index 58%
rename from litellm/llms/oobabooga.py
rename to litellm/llms/oobabooga/chat/oobabooga.py
index d47e563113..b7852eed49 100644
--- a/litellm/llms/oobabooga.py
+++ b/litellm/llms/oobabooga/chat/oobabooga.py
@@ -6,28 +6,14 @@ from typing import Any, Callable, Optional
import requests # type: ignore
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, _get_httpx_client
from litellm.utils import EmbeddingResponse, ModelResponse, Usage
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import OobaboogaError
+from .transformation import OobaboogaConfig
-
-class OobaboogaError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
-def validate_environment(api_key):
- headers = {
- "accept": "application/json",
- "content-type": "application/json",
- }
- if api_key:
- headers["Authorization"] = f"Token {api_key}"
- return headers
+oobabooga_config = OobaboogaConfig()
def completion(
@@ -40,12 +26,18 @@ def completion(
api_key,
logging_obj,
optional_params: dict,
+ litellm_params: dict,
custom_prompt_dict={},
- litellm_params=None,
logger_fn=None,
default_max_tokens_to_sample=None,
):
- headers = validate_environment(api_key)
+ headers = oobabooga_config.validate_environment(
+ api_key=api_key,
+ headers={},
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
if "https" in model:
completion_url = model
elif api_base:
@@ -58,10 +50,13 @@ def completion(
model = model
completion_url = completion_url + "/v1/chat/completions"
- data = {
- "messages": messages,
- **optional_params,
- }
+ data = oobabooga_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
## LOGGING
logging_obj.pre_call(
@@ -70,8 +65,8 @@ def completion(
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
-
- response = requests.post(
+ client = _get_httpx_client()
+ response = client.post(
completion_url,
headers=headers,
data=json.dumps(data),
@@ -80,44 +75,18 @@ def completion(
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
- ## LOGGING
- logging_obj.post_call(
- input=messages,
+ return oobabooga_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
api_key=api_key,
- original_response=response.text,
- additional_args={"complete_input_dict": data},
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
)
- print_verbose(f"raw model_response: {response.text}")
- ## RESPONSE OBJECT
- try:
- completion_response = response.json()
- except Exception:
- raise OobaboogaError(
- message=response.text, status_code=response.status_code
- )
- if "error" in completion_response:
- raise OobaboogaError(
- message=completion_response["error"],
- status_code=response.status_code,
- )
- else:
- try:
- model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
- except Exception:
- raise OobaboogaError(
- message=json.dumps(completion_response),
- status_code=response.status_code,
- )
-
- model_response.created = int(time.time())
- model_response.model = model
- usage = Usage(
- prompt_tokens=completion_response["usage"]["prompt_tokens"],
- completion_tokens=completion_response["usage"]["completion_tokens"],
- total_tokens=completion_response["usage"]["total_tokens"],
- )
- setattr(model_response, "usage", usage)
- return model_response
def embedding(
@@ -127,7 +96,7 @@ def embedding(
api_key: Optional[str],
api_base: Optional[str],
logging_obj: Any,
- optional_params=None,
+ optional_params: dict,
encoding=None,
):
# Create completion URL
@@ -153,7 +122,13 @@ def embedding(
)
# Send POST request
- headers = validate_environment(api_key)
+ headers = oobabooga_config.validate_environment(
+ api_key=api_key,
+ headers={},
+ model=model,
+ messages=[],
+ optional_params=optional_params,
+ )
response = requests.post(embeddings_url, headers=headers, json=data)
if not response.ok:
raise OobaboogaError(message=response.text, status_code=response.status_code)
diff --git a/litellm/llms/oobabooga/chat/transformation.py b/litellm/llms/oobabooga/chat/transformation.py
new file mode 100644
index 0000000000..18944a7b80
--- /dev/null
+++ b/litellm/llms/oobabooga/chat/transformation.py
@@ -0,0 +1,110 @@
+import json
+import time
+import types
+from typing import TYPE_CHECKING, Any, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.utils import token_counter
+
+from ..common_utils import OobaboogaError
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
+
+
+class OobaboogaConfig(OpenAIGPTConfig):
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[AllMessageValues]:
+ return messages
+
+ def get_error_class(
+ self,
+ error_message: str,
+ status_code: int,
+ headers: Optional[Union[dict, httpx.Headers]] = None,
+ ) -> BaseLLMException:
+ return OobaboogaError(
+ status_code=status_code, message=error_message, headers=headers
+ )
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
+ )
+
+ ## RESPONSE OBJECT
+ try:
+ completion_response = raw_response.json()
+ except Exception:
+ raise OobaboogaError(
+ message=raw_response.text, status_code=raw_response.status_code
+ )
+ if "error" in completion_response:
+ raise OobaboogaError(
+ message=completion_response["error"],
+ status_code=raw_response.status_code,
+ )
+ else:
+ try:
+ model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
+ except Exception as e:
+ raise OobaboogaError(
+ message=str(e),
+ status_code=raw_response.status_code,
+ )
+
+ model_response.created = int(time.time())
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=completion_response["usage"]["prompt_tokens"],
+ completion_tokens=completion_response["usage"]["completion_tokens"],
+ total_tokens=completion_response["usage"]["total_tokens"],
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ ) -> dict:
+ headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ }
+ if api_key is not None:
+ headers["Authorization"] = f"Token {api_key}"
+ return headers
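A small sketch of the new `OobaboogaConfig.validate_environment`, which now owns header construction; the import path mirrors the new module above, and the model name and api_key are placeholders:

```python
from litellm.llms.oobabooga.chat.transformation import OobaboogaConfig

headers = OobaboogaConfig().validate_environment(
    headers={},
    model="my-local-model",  # illustrative; not used for header construction
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    api_key="sk-placeholder",  # placeholder key
)
print(headers)
# {'accept': 'application/json', 'content-type': 'application/json',
#  'Authorization': 'Token sk-placeholder'}
```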
diff --git a/litellm/llms/oobabooga/common_utils.py b/litellm/llms/oobabooga/common_utils.py
new file mode 100644
index 0000000000..3612fed407
--- /dev/null
+++ b/litellm/llms/oobabooga/common_utils.py
@@ -0,0 +1,15 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class OobaboogaError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[dict, httpx.Headers]] = None,
+ ):
+ super().__init__(status_code=status_code, message=message, headers=headers)
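For completeness, a hedged sketch of raising the new `OobaboogaError`; it assumes `BaseLLMException` exposes `status_code` and `message` attributes, consistent with how sibling error classes in this diff use them:

```python
from litellm.llms.oobabooga.common_utils import OobaboogaError

try:
    raise OobaboogaError(status_code=400, message="bad request", headers={"x-request-id": "abc"})
except OobaboogaError as err:
    # assuming BaseLLMException stores these on the instance
    print(err.status_code, err.message)
```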
diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py
index d1496d8133..87b66ddc69 100644
--- a/litellm/llms/openai/chat/gpt_transformation.py
+++ b/litellm/llms/openai/chat/gpt_transformation.py
@@ -197,7 +197,8 @@ class OpenAIGPTConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
- encoding: str,
+ litellm_params: dict,
+ encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
diff --git a/litellm/llms/openai/chat/o1_handler.py b/litellm/llms/openai/chat/o1_handler.py
index e8515ac226..d141498cc4 100644
--- a/litellm/llms/openai/chat/o1_handler.py
+++ b/litellm/llms/openai/chat/o1_handler.py
@@ -1,63 +1,3 @@
"""
-Handler file for calls to OpenAI's o1 family of models
-
-Written separately to handle faking streaming for o1 models.
+LLM Calling done in `openai/openai.py`
"""
-
-import asyncio
-from typing import Any, Callable, List, Optional, Union
-
-from httpx._config import Timeout
-
-from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
-from litellm.llms.openai.openai import OpenAIChatCompletion
-from litellm.types.utils import ModelResponse
-from litellm.utils import CustomStreamWrapper
-
-
-class OpenAIO1ChatCompletion(OpenAIChatCompletion):
-
- def completion(
- self,
- model_response: ModelResponse,
- timeout: Union[float, Timeout],
- optional_params: dict,
- logging_obj: Any,
- model: Optional[str] = None,
- messages: Optional[list] = None,
- print_verbose: Optional[Callable[..., Any]] = None,
- api_key: Optional[str] = None,
- api_base: Optional[str] = None,
- acompletion: bool = False,
- litellm_params=None,
- logger_fn=None,
- headers: Optional[dict] = None,
- custom_prompt_dict: dict = {},
- client=None,
- organization: Optional[str] = None,
- custom_llm_provider: Optional[str] = None,
- drop_params: Optional[bool] = None,
- ):
- # stream: Optional[bool] = optional_params.pop("stream", False)
- response = super().completion(
- model_response,
- timeout,
- optional_params,
- logging_obj,
- model,
- messages,
- print_verbose,
- api_key,
- api_base,
- acompletion,
- litellm_params,
- logger_fn,
- headers,
- custom_prompt_dict,
- client,
- organization,
- custom_llm_provider,
- drop_params,
- )
-
- return response
diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py
index 5da8c4925f..e5b926f6aa 100644
--- a/litellm/llms/openai/common_utils.py
+++ b/litellm/llms/openai/common_utils.py
@@ -3,7 +3,7 @@ Common helpers / utils across al OpenAI endpoints
"""
import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
import httpx
import openai
@@ -18,7 +18,7 @@ class OpenAIError(BaseLLMException):
message: str,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
- headers: Optional[httpx.Headers] = None,
+ headers: Optional[Union[dict, httpx.Headers]] = None,
):
self.status_code = status_code
self.message = message
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index 108a31d19a..54f52c50fe 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -4,7 +4,17 @@ import os
import time
import traceback
import types
-from typing import Any, Callable, Coroutine, Iterable, Literal, Optional, Union, cast
+from typing import (
+ Any,
+ Callable,
+ Coroutine,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Union,
+ cast,
+)
import httpx
import openai
@@ -18,6 +28,7 @@ import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ProviderField
@@ -35,6 +46,7 @@ from litellm.utils import (
from ...types.llms.openai import *
from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
+from .chat.gpt_transformation import OpenAIGPTConfig
from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@@ -81,135 +93,7 @@ class MistralEmbeddingConfig:
return optional_params
-class DeepInfraConfig:
- """
- Reference: https://deepinfra.com/docs/advanced/openai_api
-
- The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
- """
-
- frequency_penalty: Optional[int] = None
- function_call: Optional[Union[str, dict]] = None
- functions: Optional[list] = None
- logit_bias: Optional[dict] = None
- max_tokens: Optional[int] = None
- n: Optional[int] = None
- presence_penalty: Optional[int] = None
- stop: Optional[Union[str, list]] = None
- temperature: Optional[int] = None
- top_p: Optional[int] = None
- response_format: Optional[dict] = None
- tools: Optional[list] = None
- tool_choice: Optional[Union[str, dict]] = None
-
- def __init__(
- self,
- frequency_penalty: Optional[int] = None,
- function_call: Optional[Union[str, dict]] = None,
- functions: Optional[list] = None,
- logit_bias: Optional[dict] = None,
- max_tokens: Optional[int] = None,
- n: Optional[int] = None,
- presence_penalty: Optional[int] = None,
- stop: Optional[Union[str, list]] = None,
- temperature: Optional[int] = None,
- top_p: Optional[int] = None,
- response_format: Optional[dict] = None,
- tools: Optional[list] = None,
- tool_choice: Optional[Union[str, dict]] = None,
- ) -> None:
- locals_ = locals().copy()
- for key, value in locals_.items():
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
-
- def get_supported_openai_params(self):
- return [
- "stream",
- "frequency_penalty",
- "function_call",
- "functions",
- "logit_bias",
- "max_tokens",
- "max_completion_tokens",
- "n",
- "presence_penalty",
- "stop",
- "temperature",
- "top_p",
- "response_format",
- "tools",
- "tool_choice",
- ]
-
- def map_openai_params(
- self,
- non_default_params: dict,
- optional_params: dict,
- model: str,
- drop_params: bool,
- ) -> dict:
- supported_openai_params = self.get_supported_openai_params()
- for param, value in non_default_params.items():
- if (
- param == "temperature"
- and value == 0
- and model == "mistralai/Mistral-7B-Instruct-v0.1"
- ): # this model does no support temperature == 0
- value = 0.0001 # close to 0
- if param == "tool_choice":
- if (
- value != "auto" and value != "none"
- ): # https://deepinfra.com/docs/advanced/function_calling
- ## UNSUPPORTED TOOL CHOICE VALUE
- if litellm.drop_params is True or drop_params is True:
- value = None
- else:
- raise litellm.utils.UnsupportedParamsError(
- message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
- value
- ),
- status_code=400,
- )
- elif param == "max_completion_tokens":
- optional_params["max_tokens"] = value
- elif param in supported_openai_params:
- if value is not None:
- optional_params[param] = value
- return optional_params
-
- def _get_openai_compatible_provider_info(
- self, api_base: Optional[str], api_key: Optional[str]
- ) -> Tuple[Optional[str], Optional[str]]:
- # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
- api_base = (
- api_base
- or get_secret_str("DEEPINFRA_API_BASE")
- or "https://api.deepinfra.com/v1/openai"
- )
- dynamic_api_key = api_key or get_secret_str("DEEPINFRA_API_KEY")
- return api_base, dynamic_api_key
-
-
-class OpenAIConfig:
+class OpenAIConfig(BaseConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
@@ -273,25 +157,12 @@ class OpenAIConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
"""
- This function returns the list of supported openai parameters for a given OpenAI Model
+ This function returns the list
+ of supported openai parameters for a given OpenAI Model
- If O1 model, returns O1 supported params
- If gpt-audio model, returns gpt-audio supported params
@@ -319,6 +190,11 @@ class OpenAIConfig:
optional_params[param] = value
return optional_params
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[AllMessageValues]:
+ return messages
+
def map_openai_params(
self,
non_default_params: dict,
@@ -349,6 +225,55 @@ class OpenAIConfig:
drop_params=drop_params,
)
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return OpenAIError(
+ status_code=status_code,
+ message=error_message,
+ headers=headers,
+ )
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ return {"model": model, "messages": messages, **optional_params}
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ raise NotImplementedError(
+ "OpenAI handler does this transformation as it uses the OpenAI SDK."
+ )
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ ) -> dict:
+ raise NotImplementedError(
+ "OpenAI handler does this validation as it uses the OpenAI SDK."
+ )
+
class OpenAIChatCompletion(BaseLLM):
@@ -483,6 +408,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
+ litellm_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
@@ -490,7 +416,6 @@ class OpenAIChatCompletion(BaseLLM):
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
- litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
@@ -516,31 +441,26 @@ class OpenAIChatCompletion(BaseLLM):
if custom_llm_provider is not None and custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}"
- # process all OpenAI compatible provider logic here
- if custom_llm_provider == "mistral":
- # check if message content passed in as list, and not string
- messages = prompt_factory( # type: ignore
- model=model,
- messages=messages,
- custom_llm_provider=custom_llm_provider,
- )
- if custom_llm_provider == "perplexity" and messages is not None:
- # check if messages.name is passed + supported, if not supported remove
- messages = prompt_factory( # type: ignore
- model=model,
- messages=messages,
- custom_llm_provider=custom_llm_provider,
- )
+
if messages is not None and custom_llm_provider is not None:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
- messages = provider_config._transform_messages(messages)
+ if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
+ provider_config, OpenAIConfig
+ ):
+ messages = provider_config._transform_messages(messages)
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
- data = {"model": model, "messages": messages, **optional_params}
+ data = OpenAIConfig().transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers or {},
+ )
try:
max_retries = data.pop("max_retries", 2)
@@ -2430,7 +2350,7 @@ class OpenAIAssistantsAPI(BaseLLM):
"""
Here's an example:
```
- from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData
+ from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData
# create thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
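A brief sketch of the new `OpenAIConfig.transform_request`, which, per the method body above, simply merges the model, messages, and optional params into the request dict; the values are illustrative:

```python
from litellm.llms.openai.openai import OpenAIConfig

data = OpenAIConfig().transform_request(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
    optional_params={"temperature": 0.1},
    litellm_params={},
    headers={},
)
print(data)  # {'model': 'gpt-4o-mini', 'messages': [...], 'temperature': 0.1}
```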
diff --git a/litellm/llms/openai_like/chat/handler.py b/litellm/llms/openai_like/chat/handler.py
index 831051a2c2..f34869bdac 100644
--- a/litellm/llms/openai_like/chat/handler.py
+++ b/litellm/llms/openai_like/chat/handler.py
@@ -26,6 +26,8 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.llms.databricks.streaming_utils import ModelResponseIterator
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.openai.openai import OpenAIConfig
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
from litellm.utils import (
Choices,
@@ -205,6 +207,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
+ print(f"e.response.text: {e.response.text}")
raise OpenAILikeError(
status_code=e.response.status_code,
message=e.response.text,
@@ -212,6 +215,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
except httpx.TimeoutException:
raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
except Exception as e:
+ print(f"e: {e}")
raise OpenAILikeError(status_code=500, message=str(e))
return OpenAILikeChatConfig._transform_response(
@@ -280,7 +284,10 @@ class OpenAILikeChatHandler(OpenAILikeBase):
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
- messages = provider_config._transform_messages(messages)
+ if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
+ provider_config, OpenAIConfig
+ ):
+ messages = provider_config._transform_messages(messages)
data = {
"model": model,
diff --git a/litellm/llms/openai_like/chat/transformation.py b/litellm/llms/openai_like/chat/transformation.py
index d60c70a378..2ea2010743 100644
--- a/litellm/llms/openai_like/chat/transformation.py
+++ b/litellm/llms/openai_like/chat/transformation.py
@@ -75,6 +75,7 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
custom_llm_provider: str,
base_model: Optional[str],
) -> ModelResponse:
+ print(f"response: {response}")
response_json = response.json()
logging_obj.post_call(
input=messages,
@@ -99,3 +100,25 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
if base_model is not None:
returned_response._hidden_params["model"] = base_model
return returned_response
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ replace_max_completion_tokens_with_max_tokens: bool = True,
+ ) -> dict:
+ mapped_params = super().map_openai_params(
+ non_default_params, optional_params, model, drop_params
+ )
+ if (
+ "max_completion_tokens" in non_default_params
+ and replace_max_completion_tokens_with_max_tokens
+ ):
+ mapped_params["max_tokens"] = non_default_params[
+ "max_completion_tokens"
+ ] # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens'
+ mapped_params.pop("max_completion_tokens", None)
+
+ return mapped_params
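A hedged sketch of the override above: `max_completion_tokens` is rewritten to `max_tokens` for OpenAI-compatible providers. The model name is illustrative:

```python
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig

mapped = OpenAILikeChatConfig().map_openai_params(
    non_default_params={"max_completion_tokens": 512},
    optional_params={},
    model="any-openai-compatible-model",  # illustrative
    drop_params=False,
)
print(mapped)  # expected: {'max_tokens': 512}, with max_completion_tokens removed
```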
diff --git a/litellm/llms/openrouter.py b/litellm/llms/openrouter.py
deleted file mode 100644
index b6ec4024fd..0000000000
--- a/litellm/llms/openrouter.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from typing import List, Dict
-import types
-
-
-class OpenrouterConfig:
- """
- Reference: https://openrouter.ai/docs#format
-
- """
-
- # OpenRouter-only parameters
- extra_body: Dict[str, List[str]] = {"transforms": []} # default transforms to []
-
- def __init__(
- self,
- transforms: List[str] = [],
- models: List[str] = [],
- route: str = "",
- ) -> None:
- locals_ = locals()
- for key, value in locals_.items():
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
diff --git a/litellm/llms/openrouter/chat/transformation.py b/litellm/llms/openrouter/chat/transformation.py
new file mode 100644
index 0000000000..9565fc99e0
--- /dev/null
+++ b/litellm/llms/openrouter/chat/transformation.py
@@ -0,0 +1,43 @@
+"""
+Support for OpenAI's `/v1/chat/completions` endpoint.
+
+Calls are made in openai/openai.py, as OpenRouter is OpenAI-compatible.
+
+Docs: https://openrouter.ai/docs/parameters
+"""
+
+from typing import Optional
+
+from litellm import get_model_info, verbose_logger
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class OpenrouterConfig(OpenAIGPTConfig):
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ mapped_openai_params = super().map_openai_params(
+ non_default_params, optional_params, model, drop_params
+ )
+
+ # OpenRouter-only parameters
+ extra_body = {}
+ transforms = non_default_params.pop("transforms", None)
+ models = non_default_params.pop("models", None)
+ route = non_default_params.pop("route", None)
+ if transforms is not None:
+ extra_body["transforms"] = transforms
+ if models is not None:
+ extra_body["models"] = models
+ if route is not None:
+ extra_body["route"] = route
+ mapped_openai_params["extra_body"] = (
+ extra_body # openai client supports `extra_body` param
+ )
+ return mapped_openai_params
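A sketch of the new OpenRouter mapping: provider-only params end up in `extra_body`, which the OpenAI client forwards as-is. The transform and route values are illustrative:

```python
from litellm.llms.openrouter.chat.transformation import OpenrouterConfig

mapped = OpenrouterConfig().map_openai_params(
    non_default_params={"transforms": ["middle-out"], "route": "fallback"},
    optional_params={},
    model="openrouter/openai/gpt-4o",  # illustrative
    drop_params=False,
)
print(mapped["extra_body"])  # expected: {'transforms': ['middle-out'], 'route': 'fallback'}
```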
diff --git a/litellm/llms/prompt_templates/common_utils.py b/litellm/llms/prompt_templates/common_utils.py
index c0798f3b22..5291f40826 100644
--- a/litellm/llms/prompt_templates/common_utils.py
+++ b/litellm/llms/prompt_templates/common_utils.py
@@ -4,7 +4,7 @@ Common utility functions used for translating messages across providers
import json
from copy import deepcopy
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union, cast
import litellm
from litellm.types.llms.openai import (
@@ -53,6 +53,13 @@ def strip_name_from_messages(
return new_messages
+def strip_none_values_from_message(message: AllMessageValues) -> AllMessageValues:
+ """
+ Strips None values from message
+ """
+ return cast(AllMessageValues, {k: v for k, v in message.items() if v is not None})
+
+
def convert_content_list_to_str(message: AllMessageValues) -> str:
"""
- handles scenario where content is list and not string
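A tiny sketch of the new helper added above:

```python
from litellm.llms.prompt_templates.common_utils import strip_none_values_from_message

msg = {"role": "assistant", "content": "hi", "tool_calls": None}
print(strip_none_values_from_message(msg))  # {'role': 'assistant', 'content': 'hi'}
```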
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 490a39c29f..13b85a3dc2 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -2856,7 +2856,7 @@ def prompt_factory(
else:
return gemini_text_image_pt(messages=messages)
elif custom_llm_provider == "mistral":
- return litellm.MistralConfig._transform_messages(messages=messages)
+ return litellm.MistralConfig()._transform_messages(messages=messages)
elif custom_llm_provider == "bedrock":
if "amazon.titan-text" in model:
return amazon_titan_pt(messages=messages)
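Since `_transform_messages` is now an instance method, callers instantiate the config first; a minimal sketch of the new call pattern (the message content is illustrative):

```python
import litellm

messages = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
transformed = litellm.MistralConfig()._transform_messages(messages=messages)
```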
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
deleted file mode 100644
index 2e9bbb3331..0000000000
--- a/litellm/llms/replicate.py
+++ /dev/null
@@ -1,609 +0,0 @@
-import asyncio
-import json
-import os
-import time
-import types
-from typing import Any, Callable, Optional, Tuple, Union
-
-import httpx # type: ignore
-import requests # type: ignore
-
-import litellm
-from litellm.llms.custom_httpx.http_handler import (
- AsyncHTTPHandler,
- get_async_httpx_client,
-)
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
-
-from .prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class ReplicateError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- self.request = httpx.Request(
- method="POST", url="https://api.replicate.com/v1/deployments"
- )
- self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
-class ReplicateConfig:
- """
- Reference: https://replicate.com/meta/llama-2-70b-chat/api
- - `prompt` (string): The prompt to send to the model.
-
- - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
-
- - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
-
- - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
-
- - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
-
- - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
-
- - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
-
- - `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either '<end>' or '<stop>'.
-
- - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
-
- - `debug` (boolean): If set to `True`, it provides debugging output in logs.
-
- Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
- """
-
- system_prompt: Optional[str] = None
- max_new_tokens: Optional[int] = None
- min_new_tokens: Optional[int] = None
- temperature: Optional[int] = None
- top_p: Optional[int] = None
- top_k: Optional[int] = None
- stop_sequences: Optional[str] = None
- seed: Optional[int] = None
- debug: Optional[bool] = None
-
- def __init__(
- self,
- system_prompt: Optional[str] = None,
- max_new_tokens: Optional[int] = None,
- min_new_tokens: Optional[int] = None,
- temperature: Optional[int] = None,
- top_p: Optional[int] = None,
- top_k: Optional[int] = None,
- stop_sequences: Optional[str] = None,
- seed: Optional[int] = None,
- debug: Optional[bool] = None,
- ) -> None:
- locals_ = locals()
- for key, value in locals_.items():
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
-
-
-# Function to start a prediction and get the prediction URL
-def start_prediction(
- version_id, input_data, api_token, api_base, logging_obj, print_verbose
-):
- base_url = api_base
- if "deployments" in version_id:
- print_verbose("\nLiteLLM: Request to custom replicate deployment")
- version_id = version_id.replace("deployments/", "")
- base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
- print_verbose(f"Deployment base URL: {base_url}\n")
- else: # assume it's a model
- base_url = f"https://api.replicate.com/v1/models/{version_id}"
- headers = {
- "Authorization": f"Token {api_token}",
- "Content-Type": "application/json",
- }
-
- initial_prediction_data = {
- "input": input_data,
- }
-
- if ":" in version_id and len(version_id) > 64:
- model_parts = version_id.split(":")
- if (
- len(model_parts) > 1 and len(model_parts[1]) == 64
- ): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
- initial_prediction_data["version"] = model_parts[1]
-
- ## LOGGING
- logging_obj.pre_call(
- input=input_data["prompt"],
- api_key="",
- additional_args={
- "complete_input_dict": initial_prediction_data,
- "headers": headers,
- "api_base": base_url,
- },
- )
-
- response = requests.post(
- f"{base_url}/predictions", json=initial_prediction_data, headers=headers
- )
- if response.status_code == 201:
- response_data = response.json()
- return response_data.get("urls", {}).get("get")
- else:
- raise ReplicateError(
- response.status_code, f"Failed to start prediction {response.text}"
- )
-
-
-async def async_start_prediction(
- version_id,
- input_data,
- api_token,
- api_base,
- logging_obj,
- print_verbose,
- http_handler: AsyncHTTPHandler,
-) -> str:
- base_url = api_base
- if "deployments" in version_id:
- print_verbose("\nLiteLLM: Request to custom replicate deployment")
- version_id = version_id.replace("deployments/", "")
- base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
- print_verbose(f"Deployment base URL: {base_url}\n")
- else: # assume it's a model
- base_url = f"https://api.replicate.com/v1/models/{version_id}"
- headers = {
- "Authorization": f"Token {api_token}",
- "Content-Type": "application/json",
- }
-
- initial_prediction_data = {
- "input": input_data,
- }
-
- if ":" in version_id and len(version_id) > 64:
- model_parts = version_id.split(":")
- if (
- len(model_parts) > 1 and len(model_parts[1]) == 64
- ): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
- initial_prediction_data["version"] = model_parts[1]
-
- ## LOGGING
- logging_obj.pre_call(
- input=input_data["prompt"],
- api_key="",
- additional_args={
- "complete_input_dict": initial_prediction_data,
- "headers": headers,
- "api_base": base_url,
- },
- )
-
- response = await http_handler.post(
- url="{}/predictions".format(base_url),
- data=json.dumps(initial_prediction_data),
- headers=headers,
- )
-
- if response.status_code == 201:
- response_data = response.json()
- return response_data.get("urls", {}).get("get")
- else:
- raise ReplicateError(
- response.status_code, f"Failed to start prediction {response.text}"
- )
-
-
-# Function to handle prediction response (non-streaming)
-def handle_prediction_response(prediction_url, api_token, print_verbose):
- output_string = ""
- headers = {
- "Authorization": f"Token {api_token}",
- "Content-Type": "application/json",
- }
-
- status = ""
- logs = ""
- while True and (status not in ["succeeded", "failed", "canceled"]):
- print_verbose(f"replicate: polling endpoint: {prediction_url}")
- time.sleep(0.5)
- response = requests.get(prediction_url, headers=headers)
- if response.status_code == 200:
- response_data = response.json()
- if "output" in response_data:
- output_string = "".join(response_data["output"])
- print_verbose(f"Non-streamed output:{output_string}")
- status = response_data.get("status", None)
- logs = response_data.get("logs", "")
- if status == "failed":
- replicate_error = response_data.get("error", "")
- raise ReplicateError(
- status_code=400,
- message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
- )
- else:
- # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
- print_verbose("Replicate: Failed to fetch prediction status and output.")
- return output_string, logs
-
-
-async def async_handle_prediction_response(
- prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
-) -> Tuple[str, Any]:
- output_string = ""
- headers = {
- "Authorization": f"Token {api_token}",
- "Content-Type": "application/json",
- }
-
- status = ""
- logs = ""
- while True and (status not in ["succeeded", "failed", "canceled"]):
- print_verbose(f"replicate: polling endpoint: {prediction_url}")
- await asyncio.sleep(0.5) # prevent replicate rate limit errors
- response = await http_handler.get(prediction_url, headers=headers)
- if response.status_code == 200:
- response_data = response.json()
- if "output" in response_data:
- output_string = "".join(response_data["output"])
- print_verbose(f"Non-streamed output:{output_string}")
- status = response_data.get("status", None)
- logs = response_data.get("logs", "")
- if status == "failed":
- replicate_error = response_data.get("error", "")
- raise ReplicateError(
- status_code=400,
- message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
- )
- else:
- # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
- print_verbose("Replicate: Failed to fetch prediction status and output.")
- return output_string, logs
-
-
-# Function to handle prediction response (streaming)
-def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
- previous_output = ""
- output_string = ""
-
- headers = {
- "Authorization": f"Token {api_token}",
- "Content-Type": "application/json",
- }
- status = ""
- while True and (status not in ["succeeded", "failed", "canceled"]):
- time.sleep(0.5) # prevent being rate limited by replicate
- print_verbose(f"replicate: polling endpoint: {prediction_url}")
- response = requests.get(prediction_url, headers=headers)
- if response.status_code == 200:
- response_data = response.json()
- status = response_data["status"]
- if "output" in response_data:
- try:
- output_string = "".join(response_data["output"])
- except Exception:
- raise ReplicateError(
- status_code=422,
- message="Unable to parse response. Got={}".format(
- response_data["output"]
- ),
- )
- new_output = output_string[len(previous_output) :]
- print_verbose(f"New chunk: {new_output}")
- yield {"output": new_output, "status": status}
- previous_output = output_string
- status = response_data["status"]
- if status == "failed":
- replicate_error = response_data.get("error", "")
- raise ReplicateError(
- status_code=400, message=f"Error: {replicate_error}"
- )
- else:
- # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
- print_verbose(
- f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
- )
-
-
-# Function to handle prediction response (streaming)
-async def async_handle_prediction_response_streaming(
- prediction_url, api_token, print_verbose
-):
- http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
- previous_output = ""
- output_string = ""
-
- headers = {
- "Authorization": f"Token {api_token}",
- "Content-Type": "application/json",
- }
- status = ""
- while True and (status not in ["succeeded", "failed", "canceled"]):
- await asyncio.sleep(0.5) # prevent being rate limited by replicate
- print_verbose(f"replicate: polling endpoint: {prediction_url}")
- response = await http_handler.get(prediction_url, headers=headers)
- if response.status_code == 200:
- response_data = response.json()
- status = response_data["status"]
- if "output" in response_data:
- try:
- output_string = "".join(response_data["output"])
- except Exception:
- raise ReplicateError(
- status_code=422,
- message="Unable to parse response. Got={}".format(
- response_data["output"]
- ),
- )
- new_output = output_string[len(previous_output) :]
- print_verbose(f"New chunk: {new_output}")
- yield {"output": new_output, "status": status}
- previous_output = output_string
- status = response_data["status"]
- if status == "failed":
- replicate_error = response_data.get("error", "")
- raise ReplicateError(
- status_code=400, message=f"Error: {replicate_error}"
- )
- else:
- # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
- print_verbose(
- f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
- )
-
-
-# Function to extract version ID from model string
-def model_to_version_id(model):
- if ":" in model:
- split_model = model.split(":")
- return split_model[1]
- return model
-
-
-def process_response(
- model_response: ModelResponse,
- result: str,
- model: str,
- encoding: Any,
- prompt: str,
-) -> ModelResponse:
- if len(result) == 0: # edge case, where result from replicate is empty
- result = " "
-
- ## Building RESPONSE OBJECT
- if len(result) >= 1:
- model_response.choices[0].message.content = result # type: ignore
-
- # Calculate usage
- prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
- completion_tokens = len(
- encoding.encode(
- model_response["choices"][0]["message"].get("content", ""),
- disallowed_special=(),
- )
- )
- model_response.model = "replicate/" + model
- usage = Usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- )
- setattr(model_response, "usage", usage)
-
- return model_response
-
-
-# Main function for prediction completion
-def completion(
- model: str,
- messages: list,
- api_base: str,
- model_response: ModelResponse,
- print_verbose: Callable,
- optional_params: dict,
- logging_obj,
- api_key,
- encoding,
- custom_prompt_dict={},
- litellm_params=None,
- logger_fn=None,
- acompletion=None,
-) -> Union[ModelResponse, CustomStreamWrapper]:
- # Start a prediction and get the prediction URL
- version_id = model_to_version_id(model)
- ## Load Config
- config = litellm.ReplicateConfig.get_config()
- for k, v in config.items():
- if (
- k not in optional_params
- ): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
- optional_params[k] = v
-
- system_prompt = None
- if optional_params is not None and "supports_system_prompt" in optional_params:
- supports_sys_prompt = optional_params.pop("supports_system_prompt")
- else:
- supports_sys_prompt = False
-
- if supports_sys_prompt:
- for i in range(len(messages)):
- if messages[i]["role"] == "system":
- first_sys_message = messages.pop(i)
- system_prompt = first_sys_message["content"]
- break
-
- if model in custom_prompt_dict:
- # check if the model has a registered custom prompt
- model_prompt_details = custom_prompt_dict[model]
- prompt = custom_prompt(
- role_dict=model_prompt_details.get("roles", {}),
- initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
- final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
- bos_token=model_prompt_details.get("bos_token", ""),
- eos_token=model_prompt_details.get("eos_token", ""),
- messages=messages,
- )
- else:
- prompt = prompt_factory(model=model, messages=messages)
-
- if prompt is None or not isinstance(prompt, str):
- raise ReplicateError(
- status_code=400,
- message="LiteLLM Error - prompt is not a string - {}".format(prompt),
- )
-
- # If system prompt is supported, and a system prompt is provided, use it
- if system_prompt is not None:
- input_data = {
- "prompt": prompt,
- "system_prompt": system_prompt,
- **optional_params,
- }
- # Otherwise, use the prompt as is
- else:
- input_data = {"prompt": prompt, **optional_params}
-
- if acompletion is not None and acompletion is True:
- return async_completion(
- model_response=model_response,
- model=model,
- prompt=prompt,
- encoding=encoding,
- optional_params=optional_params,
- version_id=version_id,
- input_data=input_data,
- api_key=api_key,
- api_base=api_base,
- logging_obj=logging_obj,
- print_verbose=print_verbose,
- ) # type: ignore
- ## COMPLETION CALL
- ## Replicate Compeltion calls have 2 steps
- ## Step1: Start Prediction: gets a prediction url
- ## Step2: Poll prediction url for response
- ## Step2: is handled with and without streaming
- model_response.created = int(
- time.time()
- ) # for pricing this must remain right before calling api
-
- prediction_url = start_prediction(
- version_id,
- input_data,
- api_key,
- api_base,
- logging_obj=logging_obj,
- print_verbose=print_verbose,
- )
- print_verbose(prediction_url)
-
- # Handle the prediction response (streaming or non-streaming)
- if "stream" in optional_params and optional_params["stream"] is True:
- print_verbose("streaming request")
- _response = handle_prediction_response_streaming(
- prediction_url, api_key, print_verbose
- )
- return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
- else:
- result, logs = handle_prediction_response(
- prediction_url, api_key, print_verbose
- )
-
- ## LOGGING
- logging_obj.post_call(
- input=prompt,
- api_key="",
- original_response=result,
- additional_args={
- "complete_input_dict": input_data,
- "logs": logs,
- "api_base": prediction_url,
- },
- )
-
- print_verbose(f"raw model_response: {result}")
-
- return process_response(
- model_response=model_response,
- result=result,
- model=model,
- encoding=encoding,
- prompt=prompt,
- )
-
-
-async def async_completion(
- model_response: ModelResponse,
- model: str,
- prompt: str,
- encoding,
- optional_params: dict,
- version_id,
- input_data,
- api_key,
- api_base,
- logging_obj,
- print_verbose,
-) -> Union[ModelResponse, CustomStreamWrapper]:
- http_handler = get_async_httpx_client(
- llm_provider=litellm.LlmProviders.REPLICATE,
- )
- prediction_url = await async_start_prediction(
- version_id,
- input_data,
- api_key,
- api_base,
- logging_obj=logging_obj,
- print_verbose=print_verbose,
- http_handler=http_handler,
- )
-
- if "stream" in optional_params and optional_params["stream"] is True:
- _response = async_handle_prediction_response_streaming(
- prediction_url, api_key, print_verbose
- )
- return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
-
- result, logs = await async_handle_prediction_response(
- prediction_url, api_key, print_verbose, http_handler=http_handler
- )
-
- return process_response(
- model_response=model_response,
- result=result,
- model=model,
- encoding=encoding,
- prompt=prompt,
- )
-
-
-# # Example usage:
-# response = completion(
-# api_key="",
-# messages=[{"content": "good morning"}],
-# model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
-# model_response=ModelResponse(),
-# print_verbose=print,
-# logging_obj=print, # stub logging_obj
-# optional_params={"stream": False}
-# )
-
-# print(response)
diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py
new file mode 100644
index 0000000000..898f350bac
--- /dev/null
+++ b/litellm/llms/replicate/chat/handler.py
@@ -0,0 +1,285 @@
+import asyncio
+import json
+import os
+import time
+import types
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import httpx # type: ignore
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import ReplicateError
+from .transformation import ReplicateConfig
+
+replicate_config = ReplicateConfig()
+
+
+# Function to handle prediction response (streaming)
+def handle_prediction_response_streaming(
+ prediction_url, api_token, print_verbose, headers: dict, http_client: HTTPHandler
+):
+ previous_output = ""
+ output_string = ""
+
+ status = ""
+ while True and (status not in ["succeeded", "failed", "canceled"]):
+ time.sleep(0.5) # prevent being rate limited by replicate
+ print_verbose(f"replicate: polling endpoint: {prediction_url}")
+ response = http_client.get(prediction_url, headers=headers)
+ if response.status_code == 200:
+ response_data = response.json()
+ status = response_data["status"]
+ if "output" in response_data:
+ try:
+ output_string = "".join(response_data["output"])
+ except Exception:
+ raise ReplicateError(
+ status_code=422,
+ message="Unable to parse response. Got={}".format(
+ response_data["output"]
+ ),
+ headers=response.headers,
+ )
+ new_output = output_string[len(previous_output) :]
+ print_verbose(f"New chunk: {new_output}")
+ yield {"output": new_output, "status": status}
+ previous_output = output_string
+ status = response_data["status"]
+ if status == "failed":
+ replicate_error = response_data.get("error", "")
+ raise ReplicateError(
+ status_code=400,
+ message=f"Error: {replicate_error}",
+ headers=response.headers,
+ )
+ else:
+ # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+ print_verbose(
+ f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+ )
+
+
+# Function to handle prediction response (streaming)
+async def async_handle_prediction_response_streaming(
+ prediction_url,
+ api_token,
+ print_verbose,
+ headers: dict,
+ http_client: AsyncHTTPHandler,
+):
+ previous_output = ""
+ output_string = ""
+
+ status = ""
+ while True and (status not in ["succeeded", "failed", "canceled"]):
+ await asyncio.sleep(0.5) # prevent being rate limited by replicate
+ print_verbose(f"replicate: polling endpoint: {prediction_url}")
+ response = await http_client.get(prediction_url, headers=headers)
+ if response.status_code == 200:
+ response_data = response.json()
+ status = response_data["status"]
+ if "output" in response_data:
+ try:
+ output_string = "".join(response_data["output"])
+ except Exception:
+ raise ReplicateError(
+ status_code=422,
+ message="Unable to parse response. Got={}".format(
+ response_data["output"]
+ ),
+ headers=response.headers,
+ )
+ new_output = output_string[len(previous_output) :]
+ print_verbose(f"New chunk: {new_output}")
+ yield {"output": new_output, "status": status}
+ previous_output = output_string
+ status = response_data["status"]
+ if status == "failed":
+ replicate_error = response_data.get("error", "")
+ raise ReplicateError(
+ status_code=400,
+ message=f"Error: {replicate_error}",
+ headers=response.headers,
+ )
+ else:
+ # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+ print_verbose(
+ f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+ )
+
+
+# Main function for prediction completion
+def completion(
+ model: str,
+ messages: list,
+ api_base: str,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ optional_params: dict,
+ litellm_params: dict,
+ logging_obj,
+ api_key,
+ encoding,
+ custom_prompt_dict={},
+ logger_fn=None,
+ acompletion=None,
+ headers={},
+) -> Union[ModelResponse, CustomStreamWrapper]:
+ headers = replicate_config.validate_environment(
+ api_key=api_key,
+ headers=headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+ # Start a prediction and get the prediction URL
+ version_id = replicate_config.model_to_version_id(model)
+ input_data = replicate_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ if acompletion is not None and acompletion is True:
+ return async_completion(
+ model_response=model_response,
+ model=model,
+ encoding=encoding,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ version_id=version_id,
+ input_data=input_data,
+ api_key=api_key,
+ api_base=api_base,
+ logging_obj=logging_obj,
+ print_verbose=print_verbose,
+ headers=headers,
+ ) # type: ignore
+ ## COMPLETION CALL
+ model_response.created = int(
+ time.time()
+ ) # for pricing this must remain right before calling api
+
+ prediction_url = replicate_config.get_complete_url(api_base, model)
+
+ ## COMPLETION CALL
+ httpx_client = _get_httpx_client(
+ params={"timeout": 600.0},
+ )
+ response = httpx_client.post(
+ url=prediction_url,
+ headers=headers,
+ data=json.dumps(input_data),
+ )
+
+ prediction_url = replicate_config.get_prediction_url(response)
+
+ # Handle the prediction response (streaming or non-streaming)
+ if "stream" in optional_params and optional_params["stream"] is True:
+ print_verbose("streaming request")
+ _response = handle_prediction_response_streaming(
+ prediction_url,
+ api_key,
+ print_verbose,
+ headers=headers,
+ http_client=httpx_client,
+ )
+ return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
+ else:
+ for _ in range(litellm.DEFAULT_MAX_RETRIES):
+ time.sleep(
+ 1
+ ) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing"
+ response = httpx_client.get(url=prediction_url, headers=headers)
+ return litellm.ReplicateConfig().transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=input_data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+
+ raise ReplicateError(
+ status_code=500,
+ message="No response received from Replicate API after max retries",
+ headers=None,
+ )
+
+
+async def async_completion(
+ model_response: ModelResponse,
+ model: str,
+ messages: List[AllMessageValues],
+ encoding,
+ optional_params: dict,
+ litellm_params: dict,
+ version_id,
+ input_data,
+ api_key,
+ api_base,
+ logging_obj,
+ print_verbose,
+ headers: dict,
+) -> Union[ModelResponse, CustomStreamWrapper]:
+
+ prediction_url = replicate_config.get_complete_url(api_base=api_base, model=model)
+ async_handler = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.REPLICATE,
+ params={"timeout": 600.0},
+ )
+ response = await async_handler.post(
+ url=prediction_url, headers=headers, data=json.dumps(input_data)
+ )
+ prediction_url = replicate_config.get_prediction_url(response)
+
+ if "stream" in optional_params and optional_params["stream"] is True:
+ _response = async_handle_prediction_response_streaming(
+ prediction_url,
+ api_key,
+ print_verbose,
+ headers=headers,
+ http_client=async_handler,
+ )
+ return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
+
+ for _ in range(litellm.DEFAULT_MAX_RETRIES):
+ await asyncio.sleep(
+ 1
+ ) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing"
+ response = await async_handler.get(url=prediction_url, headers=headers)
+ return litellm.ReplicateConfig().transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=input_data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+ # Fallback: raise an error if no response is received after max retries
+ raise ReplicateError(
+ status_code=500,
+ message="No response received from Replicate API after max retries",
+ headers=None,
+ )
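The refactored handler keeps Replicate behind the usual top-level interface; a hedged end-to-end sketch that exercises the create-then-poll flow above. The model slug is illustrative, and `REPLICATE_API_KEY` is assumed to be the expected environment variable:

```python
import os
from litellm import completion

os.environ["REPLICATE_API_KEY"] = ""  # assumed env var name; set your key here

response = completion(
    model="replicate/meta/meta-llama-3-8b-instruct",  # illustrative model slug
    messages=[{"role": "user", "content": "Say hi from LiteLLM"}],
)
print(response.choices[0].message.content)
```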
diff --git a/litellm/llms/replicate/chat/transformation.py b/litellm/llms/replicate/chat/transformation.py
new file mode 100644
index 0000000000..180c67271e
--- /dev/null
+++ b/litellm/llms/replicate/chat/transformation.py
@@ -0,0 +1,312 @@
+import types
+from typing import TYPE_CHECKING, Any, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.utils import token_counter
+
+from ..common_utils import ReplicateError
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
+
+
+class ReplicateConfig(BaseConfig):
+ """
+ Reference: https://replicate.com/meta/llama-2-70b-chat/api
+ - `prompt` (string): The prompt to send to the model.
+
+ - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
+
+ - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
+
+ - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
+
+ - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
+
+ - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
+
+ - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
+
+    - `stop_sequences` (string): A comma-separated list of sequences at which to stop generation. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either '<end>' or '<stop>'.
+
+ - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
+
+ - `debug` (boolean): If set to `True`, it provides debugging output in logs.
+
+    Note that Replicate's handling of these parameters varies across models, so not every parameter is available for every model (see the example below).
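+
+    Example (a minimal sketch - the model name and parameter values are illustrative):
+
+        import litellm
+
+        response = litellm.completion(
+            model="replicate/meta/llama-2-70b-chat",
+            messages=[{"role": "user", "content": "Hello"}],
+            max_tokens=64,     # mapped to Replicate's `max_new_tokens`
+            temperature=0.7,
+            stop="###",        # mapped to `stop_sequences`
+        )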
+ """
+
+ system_prompt: Optional[str] = None
+ max_new_tokens: Optional[int] = None
+ min_new_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+ top_k: Optional[int] = None
+ stop_sequences: Optional[str] = None
+ seed: Optional[int] = None
+ debug: Optional[bool] = None
+
+ def __init__(
+ self,
+ system_prompt: Optional[str] = None,
+ max_new_tokens: Optional[int] = None,
+ min_new_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+ top_k: Optional[int] = None,
+ stop_sequences: Optional[str] = None,
+ seed: Optional[int] = None,
+ debug: Optional[bool] = None,
+ ) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
+ def get_supported_openai_params(self, model: str) -> list:
+ return [
+ "stream",
+ "temperature",
+ "max_tokens",
+ "top_p",
+ "stop",
+ "seed",
+ "tools",
+ "tool_choice",
+ "functions",
+ "function_call",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ for param, value in non_default_params.items():
+ if param == "stream":
+ optional_params["stream"] = value
+ if param == "max_tokens":
+ if "vicuna" in model or "flan" in model:
+ optional_params["max_length"] = value
+ elif "meta/codellama-13b" in model:
+ optional_params["max_tokens"] = value
+ else:
+ optional_params["max_new_tokens"] = value
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "stop":
+ optional_params["stop_sequences"] = value
+
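+        # e.g. an OpenAI-style max_tokens=100 becomes {"max_new_tokens": 100} for most
+        # models ({"max_length": 100} for vicuna/flan models), and stop values are
+        # passed through as `stop_sequences`.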
+ return optional_params
+
+ # Function to extract version ID from model string
+ def model_to_version_id(self, model: str) -> str:
+ if ":" in model:
+ split_model = model.split(":")
+ return split_model[1]
+ return model
+
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[AllMessageValues]:
+ return messages
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return ReplicateError(
+ status_code=status_code, message=error_message, headers=headers
+ )
+
+ def get_complete_url(self, api_base: str, model: str) -> str:
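+        """
+        Builds the Replicate prediction endpoint from the model string: version ids
+        containing "deployments/" use the deployments endpoint, everything else the
+        models endpoint, with "/predictions" appended in both cases.
+        """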
+ version_id = self.model_to_version_id(model)
+ base_url = api_base
+ if "deployments" in version_id:
+ version_id = version_id.replace("deployments/", "")
+ base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+ else: # assume it's a model
+ base_url = f"https://api.replicate.com/v1/models/{version_id}"
+
+ base_url = f"{base_url}/predictions"
+ return base_url
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ ## Load Config
+ config = litellm.ReplicateConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+ ): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
+ optional_params[k] = v
+
+ system_prompt = None
+ if optional_params is not None and "supports_system_prompt" in optional_params:
+ supports_sys_prompt = optional_params.pop("supports_system_prompt")
+ else:
+ supports_sys_prompt = False
+
+ if supports_sys_prompt:
+ for i in range(len(messages)):
+ if messages[i]["role"] == "system":
+ first_sys_message = messages.pop(i)
+ system_prompt = convert_content_list_to_str(first_sys_message)
+ break
+
+ if model in litellm.custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = litellm.custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", {}),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+ bos_token=model_prompt_details.get("bos_token", ""),
+ eos_token=model_prompt_details.get("eos_token", ""),
+ messages=messages,
+ )
+ else:
+ prompt = prompt_factory(model=model, messages=messages)
+
+ if prompt is None or not isinstance(prompt, str):
+ raise ReplicateError(
+ status_code=400,
+ message="LiteLLM Error - prompt is not a string - {}".format(prompt),
+ headers={},
+ )
+
+ # If system prompt is supported, and a system prompt is provided, use it
+ if system_prompt is not None:
+ input_data = {
+ "prompt": prompt,
+ "system_prompt": system_prompt,
+ **optional_params,
+ }
+ # Otherwise, use the prompt as is
+ else:
+ input_data = {"prompt": prompt, **optional_params}
+
+ version_id = self.model_to_version_id(model)
+ request_data: dict = {"input": input_data}
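+        # Replicate expects the generation parameters nested under "input"; a top-level
+        # "version" key is added below when the model string pins a 64-character version.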
+ if ":" in version_id and len(version_id) > 64:
+ model_parts = version_id.split(":")
+ if (
+ len(model_parts) > 1 and len(model_parts[1]) == 64
+            ):  ## checks if the model string includes a 64-character version hash - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+ request_data["version"] = model_parts[1]
+
+ return request_data
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ logging_obj.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
+ )
+ raw_response_json = raw_response.json()
+ if raw_response_json.get("status") != "succeeded":
+ raise ReplicateError(
+ status_code=422,
+ message="LiteLLM Error - prediction not succeeded - {}".format(
+ raw_response_json
+ ),
+ headers=raw_response.headers,
+ )
+ outputs = raw_response_json.get("output", [])
+ response_str = "".join(outputs)
+ if len(response_str) == 0: # edge case, where result from replicate is empty
+ response_str = " "
+
+ ## Building RESPONSE OBJECT
+ if len(response_str) >= 1:
+ model_response.choices[0].message.content = response_str # type: ignore
+
+ # Calculate usage
+ prompt_tokens = token_counter(model=model, messages=messages)
+ completion_tokens = token_counter(
+ model=model,
+ text=response_str,
+ count_response_tokens=True,
+ )
+ model_response.model = "replicate/" + model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+            "PaLM was decommissioned in October 2024. Please use the `gemini/` route for Google AI Studio Gemini models. Announcement: https://ai.google.dev/palm_docs/palm?hl=en"
+ )
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
+ def get_prediction_url(self, response: httpx.Response) -> str:
+ """
+ response json: {
+ ...,
+        "urls": {
+            "cancel": "https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4/cancel",
+            "get": "https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4",
+            "stream": "https://stream-b.svc.rno2.c.replicate.net/v1/streams/eot4gbydowuin4snhncydwxt57dfwgsc3w3snycx5nid7oef7jga"
+        }
+ }
+ """
+ response_json = response.json()
+ prediction_url = response_json.get("urls", {}).get("get")
+ if prediction_url is None:
+ raise ReplicateError(
+ status_code=400,
+ message="LiteLLM Error - prediction url is None - {}".format(
+ response_json
+ ),
+ headers=response.headers,
+ )
+ return prediction_url
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ ) -> dict:
+ headers = {
+ "Authorization": f"Token {api_key}",
+ "Content-Type": "application/json",
+ }
+ return headers
diff --git a/litellm/llms/replicate/common_utils.py b/litellm/llms/replicate/common_utils.py
new file mode 100644
index 0000000000..98a5936ccf
--- /dev/null
+++ b/litellm/llms/replicate/common_utils.py
@@ -0,0 +1,15 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class ReplicateError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[dict, httpx.Headers]],
+ ):
+ super().__init__(status_code=status_code, message=message, headers=headers)
diff --git a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py
index 648f184e89..a0961621fd 100644
--- a/litellm/llms/sagemaker/completion/handler.py
+++ b/litellm/llms/sagemaker/completion/handler.py
@@ -363,6 +363,7 @@ class SagemakerLLM(BaseAWSLLM):
messages=messages,
optional_params=optional_params,
encoding=encoding,
+ litellm_params=litellm_params,
)
async def make_async_call(
@@ -562,6 +563,7 @@ class SagemakerLLM(BaseAWSLLM):
messages=messages,
optional_params=optional_params,
encoding=encoding,
+ litellm_params=litellm_params,
)
def embedding(
diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py
index e6bfbb33f6..91efa86adf 100644
--- a/litellm/llms/sagemaker/completion/transformation.py
+++ b/litellm/llms/sagemaker/completion/transformation.py
@@ -202,6 +202,7 @@ class SagemakerConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
diff --git a/litellm/llms/sambanova/chat.py b/litellm/llms/sambanova/chat.py
index a194a1e0f7..c5e0de4d99 100644
--- a/litellm/llms/sambanova/chat.py
+++ b/litellm/llms/sambanova/chat.py
@@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
import types
from typing import Optional
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
-class SambanovaConfig:
+
+class SambanovaConfig(OpenAIGPTConfig):
"""
Reference: https://community.sambanova.ai/t/create-chat-completion-api/
@@ -18,9 +20,7 @@ class SambanovaConfig:
max_tokens: Optional[int] = None
response_format: Optional[dict] = None
seed: Optional[int] = None
- stop: Optional[str] = None
stream: Optional[bool] = None
- temperature: Optional[float] = None
top_p: Optional[int] = None
tool_choice: Optional[str] = None
tools: Optional[list] = None
@@ -46,21 +46,7 @@ class SambanovaConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
"""
@@ -80,12 +66,3 @@ class SambanovaConfig:
"tools",
"user",
]
-
- def map_openai_params(
- self, model: str, non_default_params: dict, optional_params: dict
- ) -> dict:
- supported_openai_params = self.get_supported_openai_params(model=model)
- for param, value in non_default_params.items():
- if param in supported_openai_params:
- optional_params[param] = value
- return optional_params
diff --git a/litellm/llms/text_completion_codestral.py b/litellm/llms/text_completion_codestral.py
index d3c1ae3cbe..f951b897b8 100644
--- a/litellm/llms/text_completion_codestral.py
+++ b/litellm/llms/text_completion_codestral.py
@@ -22,6 +22,7 @@ from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
+from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.utils import (
Choices,
@@ -91,19 +92,17 @@ async def make_call(
return completion_stream
-class MistralTextCompletionConfig:
+class MistralTextCompletionConfig(OpenAITextCompletionConfig):
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
"""
suffix: Optional[str] = None
temperature: Optional[int] = None
- top_p: Optional[float] = None
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
stream: Optional[bool] = None
random_seed: Optional[int] = None
- stop: Optional[str] = None
def __init__(
self,
@@ -123,23 +122,9 @@ class MistralTextCompletionConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
- def get_supported_openai_params(self):
+ def get_supported_openai_params(self, model: str):
return [
"suffix",
"temperature",
@@ -151,7 +136,13 @@ class MistralTextCompletionConfig:
"stop",
]
- def map_openai_params(self, non_default_params: dict, optional_params: dict):
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
for param, value in non_default_params.items():
if param == "suffix":
optional_params["suffix"] = value
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
index 02171d032d..5fef37d313 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
@@ -1,22 +1,20 @@
-from typing import List, Literal, Tuple
+from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
from litellm import supports_response_schema, supports_system_messages, verbose_logger
+from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType
-class VertexAIError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- self.request = httpx.Request(
- method="POST", url=" https://cloud.google.com/vertex-ai/"
- )
- self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
+class VertexAIError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[Dict, httpx.Headers]] = None,
+ ):
+ super().__init__(message=message, status_code=status_code, headers=headers)
def get_supports_system_message(
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py
index c9fe6e3f4d..7e16571f55 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py
@@ -299,11 +299,13 @@ def _transform_request_body(
try:
if custom_llm_provider == "gemini":
- content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
+ content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
messages=messages
)
else:
- content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
+ content = litellm.VertexGeminiConfig()._transform_messages(
+ messages=messages
+ )
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
@@ -460,15 +462,3 @@ def _transform_system_message(
return SystemInstructions(parts=system_content_blocks), messages
return None, messages
-
-
-def set_headers(auth_header: Optional[str], extra_headers: Optional[dict]) -> dict:
- headers = {
- "Content-Type": "application/json",
- }
- if auth_header is not None:
- headers["Authorization"] = f"Bearer {auth_header}"
- if extra_headers is not None:
- headers.update(extra_headers)
-
- return headers
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 4287ed1bc2..454da4d4af 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -20,6 +20,7 @@ from typing import (
Optional,
Tuple,
Union,
+ cast,
)
import httpx # type: ignore
@@ -30,6 +31,7 @@ import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@@ -86,10 +88,16 @@ from .transformation import (
_gemini_convert_messages_with_history,
_process_gemini_image,
async_transform_request_body,
- set_headers,
sync_transform_request_body,
)
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
+
class VertexAIConfig:
"""
@@ -277,7 +285,7 @@ class VertexAIConfig:
]
-class VertexGeminiConfig:
+class VertexGeminiConfig(BaseConfig):
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
@@ -338,23 +346,9 @@ class VertexGeminiConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
- def get_supported_openai_params(self):
+ def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
@@ -473,12 +467,11 @@ class VertexGeminiConfig:
def map_openai_params(
self,
+ non_default_params: Dict,
+ optional_params: Dict,
model: str,
- non_default_params: dict,
- optional_params: dict,
drop_params: bool,
- ):
-
+ ) -> Dict:
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
@@ -751,38 +744,38 @@ class VertexGeminiConfig:
return model_response
- def _transform_response(
+ def transform_response(
self,
model: str,
- response: httpx.Response,
+ raw_response: httpx.Response,
model_response: ModelResponse,
- logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
- optional_params: dict,
- litellm_params: dict,
- api_key: str,
- data: Union[dict, str, RequestBody],
- messages: List,
- print_verbose,
- encoding,
+ logging_obj: LoggingClass,
+ request_data: Dict,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
) -> ModelResponse:
-
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
- original_response=response.text,
- additional_args={"complete_input_dict": data},
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
- completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
+ completion_response = GenerateContentResponseBody(**raw_response.json()) # type: ignore
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
- response.text, str(e)
+ raw_response.text, str(e)
),
status_code=422,
+ headers=raw_response.headers,
)
## GET MODEL ##
@@ -915,14 +908,53 @@ class VertexGeminiConfig:
completion_response, str(e)
),
status_code=422,
+ headers=raw_response.headers,
)
return model_response
- @staticmethod
- def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[ContentType]:
return _gemini_convert_messages_with_history(messages=messages)
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return VertexAIError(
+ message=error_message, status_code=status_code, headers=headers
+ )
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ headers: Dict,
+ ) -> Dict:
+ raise NotImplementedError(
+ "Vertex AI has a custom implementation of transform_request. Needs sync + async."
+ )
+
+ def validate_environment(
+ self,
+ headers: Optional[Dict],
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ api_key: Optional[str] = None,
+ ) -> Dict:
+ default_headers = {
+ "Content-Type": "application/json",
+ }
+ if api_key is not None:
+ default_headers["Authorization"] = f"Bearer {api_key}"
+ if headers is not None:
+ default_headers.update(headers)
+
+ return default_headers
+
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
@@ -978,23 +1010,9 @@ class GoogleAIStudioGeminiConfig(
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
- def get_supported_openai_params(self):
+ def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
@@ -1012,22 +1030,27 @@ class GoogleAIStudioGeminiConfig(
def map_openai_params(
self,
- model: str,
non_default_params: Dict,
optional_params: Dict,
+ model: str,
drop_params: bool,
- ):
+ ) -> Dict:
+
# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
return super().map_openai_params(
- model, non_default_params, optional_params, drop_params
+ model=model,
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ drop_params=drop_params,
)
- @staticmethod
- def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
@@ -1075,9 +1098,14 @@ async def make_call(
raise VertexAIError(
status_code=e.response.status_code,
message=VertexGeminiConfig().translate_exception_str(exception_string),
+ headers=e.response.headers,
)
if response.status_code != 200:
- raise VertexAIError(status_code=response.status_code, message=response.text)
+ raise VertexAIError(
+ status_code=response.status_code,
+ message=response.text,
+ headers=response.headers,
+ )
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
@@ -1111,7 +1139,11 @@ def make_sync_call(
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
- raise VertexAIError(status_code=response.status_code, message=response.read())
+ raise VertexAIError(
+ status_code=response.status_code,
+ message=str(response.read()),
+ headers=response.headers,
+ )
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
@@ -1182,7 +1214,13 @@ class VertexLLM(VertexBase):
should_use_v1beta1_features=should_use_v1beta1_features,
)
- headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
## LOGGING
logging_obj.pre_call(
@@ -1263,7 +1301,13 @@ class VertexLLM(VertexBase):
should_use_v1beta1_features=should_use_v1beta1_features,
)
- headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
request_body = await async_transform_request_body(**data) # type: ignore
_async_client_params = {}
@@ -1287,23 +1331,32 @@ class VertexLLM(VertexBase):
)
try:
- response = await client.post(api_base, headers=headers, json=request_body) # type: ignore
+ response = await client.post(
+ api_base, headers=headers, json=cast(dict, request_body)
+ ) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
- raise VertexAIError(status_code=error_code, message=err.response.text)
+ raise VertexAIError(
+ status_code=error_code,
+ message=err.response.text,
+ headers=err.response.headers,
+ )
except httpx.TimeoutException:
- raise VertexAIError(status_code=408, message="Timeout error occurred.")
+ raise VertexAIError(
+ status_code=408,
+ message="Timeout error occurred.",
+ headers=None,
+ )
- return VertexGeminiConfig()._transform_response(
+ return VertexGeminiConfig().transform_response(
model=model,
- response=response,
+ raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key="",
- data=request_body,
+ request_data=cast(dict, request_body),
messages=messages,
- print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
@@ -1421,7 +1474,13 @@ class VertexLLM(VertexBase):
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
- headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
## TRANSFORMATION ##
data = sync_transform_request_body(**transform_request_params)
@@ -1479,21 +1538,28 @@ class VertexLLM(VertexBase):
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
- raise VertexAIError(status_code=error_code, message=err.response.text)
+ raise VertexAIError(
+ status_code=error_code,
+ message=err.response.text,
+ headers=err.response.headers,
+ )
except httpx.TimeoutException:
- raise VertexAIError(status_code=408, message="Timeout error occurred.")
+ raise VertexAIError(
+ status_code=408,
+ message="Timeout error occurred.",
+ headers=None,
+ )
- return VertexGeminiConfig()._transform_response(
+ return VertexGeminiConfig().transform_response(
model=model,
- response=response,
+ raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
api_key="",
- data=data, # type: ignore
+ request_data=data, # type: ignore
messages=messages,
- print_verbose=print_verbose,
encoding=encoding,
)
diff --git a/litellm/llms/volcengine.py b/litellm/llms/volcengine.py
index 9b288c8681..a8ecb67663 100644
--- a/litellm/llms/volcengine.py
+++ b/litellm/llms/volcengine.py
@@ -2,9 +2,10 @@ import types
from typing import Literal, Optional, Union
import litellm
+from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
-class VolcEngineConfig:
+class VolcEngineConfig(OpenAILikeChatConfig):
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
@@ -38,21 +39,7 @@ class VolcEngineConfig:
@classmethod
def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
+ return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
return [
@@ -77,14 +64,3 @@ class VolcEngineConfig:
"max_retries",
"extra_headers",
] # works across all models
-
- def map_openai_params(
- self, non_default_params: dict, optional_params: dict, model: str
- ) -> dict:
- supported_openai_params = self.get_supported_openai_params(model)
- for param, value in non_default_params.items():
- if param == "max_completion_tokens":
- optional_params["max_tokens"] = value
- elif param in supported_openai_params:
- optional_params[param] = value
- return optional_params
diff --git a/litellm/llms/watsonx/completion/transformation.py b/litellm/llms/watsonx/completion/transformation.py
index ab26890e00..6f2b188106 100644
--- a/litellm/llms/watsonx/completion/transformation.py
+++ b/litellm/llms/watsonx/completion/transformation.py
@@ -274,6 +274,7 @@ class IBMWatsonXAIConfig(BaseConfig):
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: Dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
diff --git a/litellm/main.py b/litellm/main.py
index c639f237d9..cab0c7167d 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -83,26 +83,13 @@ from litellm.utils import (
from ._logging import verbose_logger
from .caching.caching import disable_cache, enable_cache, update_cache
from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import (
- aleph_alpha,
- baseten,
- maritalk,
- nlp_cloud,
- ollama_chat,
- oobabooga,
- openrouter,
- palm,
- petals,
- replicate,
-)
-from .llms.ai21 import completion as ai21
+from .llms import aleph_alpha, baseten, maritalk, ollama_chat, petals
from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.azure.audio_transcriptions import AzureAudioTranscription
from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
-from .llms.azure_ai.chat import AzureAIChatCompletion
+from .llms.azure.completion.handler import AzureTextCompletion
from .llms.azure_ai.embed import AzureAIEmbedding
-from .llms.azure_text import AzureTextCompletion
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
@@ -111,13 +98,16 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat.handler import DatabricksChatCompletion
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
+from .llms.deprecated_providers import palm
from .llms.groq.chat.handler import GroqChatCompletion
-from .llms.huggingface_restapi import Huggingface
+from .llms.huggingface.chat.handler import Huggingface
+from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
+from .llms.oobabooga.chat import oobabooga
from .llms.ollama.completion import handler as ollama
from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
-from .llms.openai.chat.o1_handler import OpenAIO1ChatCompletion
from .llms.openai.completion.handler import OpenAITextCompletion
from .llms.openai.openai import OpenAIChatCompletion
+from .llms.openai_like.chat.handler import OpenAILikeChatHandler
from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.common_utils import get_completion_messages
@@ -131,6 +121,7 @@ from .llms.prompt_templates.factory import (
)
from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
+from .llms.replicate.chat.handler import completion as replicate_chat_completion
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
@@ -159,7 +150,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
VertexAIModelGardenModels,
)
-from .llms.vllm.completion import handler
+from .llms.vllm.completion import handler as vllm_handler
from .llms.watsonx.chat.handler import WatsonXChatHandler
from .llms.watsonx.completion.handler import IBMWatsonXAI
from .types.llms.openai import (
@@ -196,12 +187,10 @@ from litellm.utils import (
####### ENVIRONMENT VARIABLES ###################
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
-openai_o1_chat_completions = OpenAIO1ChatCompletion()
openai_audio_transcriptions = OpenAIAudioTranscription()
databricks_chat_completions = DatabricksChatCompletion()
groq_chat_completions = GroqChatCompletion()
together_ai_text_completions = TogetherAITextCompletion()
-azure_ai_chat_completions = AzureAIChatCompletion()
azure_ai_embedding = AzureAIEmbedding()
anthropic_chat_completions = AnthropicChatCompletion()
azure_chat_completions = AzureChatCompletion()
@@ -228,6 +217,7 @@ watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()
+openai_like_chat_completion = OpenAILikeChatHandler()
databricks_embedding = DatabricksEmbeddingHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
sagemaker_chat_completion = SagemakerChatHandler()
@@ -449,6 +439,7 @@ async def acompletion(
or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova"
or custom_llm_provider == "ai21_chat"
+ or custom_llm_provider == "ai21"
or custom_llm_provider == "volcengine"
or custom_llm_provider == "codestral"
or custom_llm_provider == "text-completion-codestral"
@@ -1316,7 +1307,7 @@ def completion( # type: ignore # noqa: PLR0915
## COMPLETION CALL
try:
- response = azure_ai_chat_completions.completion(
+ response = openai_chat_completions.completion(
model=model,
messages=messages,
headers=headers,
@@ -1513,9 +1504,7 @@ def completion( # type: ignore # noqa: PLR0915
or custom_llm_provider == "nvidia_nim"
or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova"
- or custom_llm_provider == "ai21_chat"
or custom_llm_provider == "volcengine"
- or custom_llm_provider == "codestral"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
@@ -1562,46 +1551,25 @@ def completion( # type: ignore # noqa: PLR0915
## COMPLETION CALL
try:
- if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
- response = openai_o1_chat_completions.completion(
- model=model,
- messages=messages,
- headers=headers,
- model_response=model_response,
- print_verbose=print_verbose,
- api_key=api_key,
- api_base=api_base,
- acompletion=acompletion,
- logging_obj=logging,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- timeout=timeout, # type: ignore
- custom_prompt_dict=custom_prompt_dict,
- client=client, # pass AsyncOpenAI, OpenAI client
- organization=organization,
- custom_llm_provider=custom_llm_provider,
- )
- else:
- response = openai_chat_completions.completion(
- model=model,
- messages=messages,
- headers=headers,
- model_response=model_response,
- print_verbose=print_verbose,
- api_key=api_key,
- api_base=api_base,
- acompletion=acompletion,
- logging_obj=logging,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- timeout=timeout, # type: ignore
- custom_prompt_dict=custom_prompt_dict,
- client=client, # pass AsyncOpenAI, OpenAI client
- organization=organization,
- custom_llm_provider=custom_llm_provider,
- )
+ response = openai_chat_completions.completion(
+ model=model,
+ messages=messages,
+ headers=headers,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ api_key=api_key,
+ api_base=api_base,
+ acompletion=acompletion,
+ logging_obj=logging,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout, # type: ignore
+ custom_prompt_dict=custom_prompt_dict,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ organization=organization,
+ custom_llm_provider=custom_llm_provider,
+ )
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(
@@ -1627,7 +1595,6 @@ def completion( # type: ignore # noqa: PLR0915
or model in litellm.replicate_models
):
# Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
- replicate_key = None
replicate_key = (
api_key
or litellm.replicate_key
@@ -1645,7 +1612,7 @@ def completion( # type: ignore # noqa: PLR0915
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
- model_response = replicate.completion( # type: ignore
+ model_response = replicate_chat_completion( # type: ignore
model=model,
messages=messages,
api_base=api_base,
@@ -1659,6 +1626,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
acompletion=acompletion,
+ headers=headers,
)
if optional_params.get("stream", False) is True:
@@ -1806,7 +1774,7 @@ def completion( # type: ignore # noqa: PLR0915
or "https://api.nlpcloud.io/v1/gpu/"
)
- response = nlp_cloud.completion(
+ response = nlp_cloud_chat_completion(
model=model,
messages=messages,
api_base=api_base,
@@ -1969,10 +1937,10 @@ def completion( # type: ignore # noqa: PLR0915
api_base
or litellm.api_base
or get_secret("MARITALK_API_BASE")
- or "https://chat.maritaca.ai/api/chat/inference"
+ or "https://chat.maritaca.ai/api"
)
- model_response = maritalk.completion(
+ model_response = openai_like_chat_completion.completion(
model=model,
messages=messages,
api_base=api_base,
@@ -1984,17 +1952,10 @@ def completion( # type: ignore # noqa: PLR0915
encoding=encoding,
api_key=maritalk_key,
logging_obj=logging,
+ custom_llm_provider="maritalk",
+ custom_prompt_dict=custom_prompt_dict,
)
- if "stream" in optional_params and optional_params["stream"] is True:
- # don't try to access stream object,
- response = CustomStreamWrapper(
- model_response,
- model,
- custom_llm_provider="maritalk",
- logging_obj=logging,
- )
- return response
response = model_response
elif custom_llm_provider == "huggingface":
custom_llm_provider = "huggingface"
@@ -2012,7 +1973,7 @@ def completion( # type: ignore # noqa: PLR0915
model=model,
messages=messages,
api_base=api_base, # type: ignore
- headers=hf_headers,
+ headers=hf_headers or {},
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
@@ -2024,6 +1985,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout, # type: ignore
+ client=client,
)
if (
"stream" in optional_params
@@ -2146,7 +2108,7 @@ def completion( # type: ignore # noqa: PLR0915
headers = openrouter_headers
## Load Config
- config = openrouter.OpenrouterConfig.get_config()
+ config = litellm.OpenrouterConfig.get_config()
for k, v in config.items():
if k == "extra_body":
# we use openai 'extra_body' to pass openrouter specific params - transforms, route, models
@@ -2190,30 +2152,9 @@ def completion( # type: ignore # noqa: PLR0915
"""
pass
elif custom_llm_provider == "palm":
- palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key
-
- # palm does not support streaming as yet :(
- model_response = palm.completion(
- model=model,
- messages=messages,
- model_response=model_response,
- print_verbose=print_verbose,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- encoding=encoding,
- api_key=palm_api_key,
- logging_obj=logging,
+ raise ValueError(
+ "Palm was decommisioned on October 2024. Please use the `gemini/` route for Gemini Google AI Studio Models. Announcement: https://ai.google.dev/palm_docs/palm?hl=en"
)
- # fake palm streaming
- if "stream" in optional_params and optional_params["stream"] is True:
- # fake streaming for palm
- resp_string = model_response["choices"][0]["message"]["content"]
- response = CustomStreamWrapper(
- resp_string, model, custom_llm_provider="palm", logging_obj=logging
- )
- return response
- response = model_response
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
@@ -2475,51 +2416,9 @@ def completion( # type: ignore # noqa: PLR0915
):
return _model_response
response = _model_response
- elif custom_llm_provider == "ai21":
- custom_llm_provider = "ai21"
- ai21_key = (
- api_key
- or litellm.ai21_key
- or os.environ.get("AI21_API_KEY")
- or litellm.api_key
- )
-
- api_base = (
- api_base
- or litellm.api_base
- or get_secret("AI21_API_BASE")
- or "https://api.ai21.com/studio/v1/"
- )
-
- model_response = ai21.completion(
- model=model,
- messages=messages,
- api_base=api_base,
- model_response=model_response,
- print_verbose=print_verbose,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- encoding=encoding,
- api_key=ai21_key,
- logging_obj=logging,
- )
-
- if "stream" in optional_params and optional_params["stream"] is True:
- # don't try to access stream object,
- response = CustomStreamWrapper(
- model_response,
- model,
- custom_llm_provider="ai21",
- logging_obj=logging,
- )
- return response
-
- ## RESPONSE OBJECT
- response = model_response
elif custom_llm_provider == "sagemaker_chat":
# boto3 reads keys from .env
- response = sagemaker_chat_completion.completion(
+ model_response = sagemaker_chat_completion.completion(
model=model,
messages=messages,
model_response=model_response,
@@ -2531,9 +2430,13 @@ def completion( # type: ignore # noqa: PLR0915
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
- headers=headers or {},
)
- elif custom_llm_provider == "sagemaker":
+
+ ## RESPONSE OBJECT
+ response = model_response
+ elif (
+ custom_llm_provider == "sagemaker"
+ ):
# boto3 reads keys from .env
model_response = sagemaker_llm.completion(
model=model,
@@ -2691,7 +2594,7 @@ def completion( # type: ignore # noqa: PLR0915
response = response
elif custom_llm_provider == "vllm":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
- model_response = handler.completion(
+ model_response = vllm_handler.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
@@ -3872,6 +3775,7 @@ async def atext_completion(
or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova"
or custom_llm_provider == "ai21_chat"
+ or custom_llm_provider == "ai21"
or custom_llm_provider == "volcengine"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek"
diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
index 696e864cb6..9161eb8493 100644
--- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
+++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
@@ -56,6 +56,7 @@ class AnthropicPassthroughLoggingHandler:
request_data={},
encoding=litellm.encoding,
json_mode=False,
+ litellm_params={},
)
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
index 2773979adf..0d2b2f9afe 100644
--- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
+++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
@@ -41,20 +41,19 @@ class VertexPassthroughLoggingHandler:
instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: litellm.ModelResponse = (
- instance_of_vertex_llm._transform_response(
+ instance_of_vertex_llm.transform_response(
model=model,
messages=[
{"role": "user", "content": "no-message-pass-through-endpoint"}
],
- response=httpx_response,
+ raw_response=httpx_response,
model_response=litellm.ModelResponse(),
logging_obj=logging_obj,
optional_params={},
litellm_params={},
api_key="",
- data={},
- print_verbose=litellm.print_verbose,
- encoding=None,
+ request_data={},
+ encoding=litellm.encoding,
)
)
kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
diff --git a/litellm/utils.py b/litellm/utils.py
index acb3ad07c0..05af5d0252 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2923,22 +2923,16 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
- if stream:
- optional_params["stream"] = stream
- # return optional_params
- if max_tokens is not None:
- if "vicuna" in model or "flan" in model:
- optional_params["max_length"] = max_tokens
- elif "meta/codellama-13b" in model:
- optional_params["max_tokens"] = max_tokens
- else:
- optional_params["max_new_tokens"] = max_tokens
- if temperature is not None:
- optional_params["temperature"] = temperature
- if top_p is not None:
- optional_params["top_p"] = top_p
- if stop is not None:
- optional_params["stop_sequences"] = stop
+ optional_params = litellm.ReplicateConfig().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
+ )
elif custom_llm_provider == "predibase":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
@@ -2954,7 +2948,14 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.HuggingfaceConfig().map_openai_params(
- non_default_params=non_default_params, optional_params=optional_params
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "together_ai":
## check if unsupported param passed in
@@ -2973,53 +2974,6 @@ def get_optional_params( # noqa: PLR0915
else False
),
)
- elif custom_llm_provider == "ai21":
- ## check if unsupported param passed in
- supported_params = get_supported_openai_params(
- model=model, custom_llm_provider=custom_llm_provider
- )
- _check_valid_arg(supported_params=supported_params)
-
- if stream:
- optional_params["stream"] = stream
- if n is not None:
- optional_params["numResults"] = n
- if max_tokens is not None:
- optional_params["maxTokens"] = max_tokens
- if temperature is not None:
- optional_params["temperature"] = temperature
- if top_p is not None:
- optional_params["topP"] = top_p
- if stop is not None:
- optional_params["stopSequences"] = stop
- if frequency_penalty is not None:
- optional_params["frequencyPenalty"] = {"scale": frequency_penalty}
- if presence_penalty is not None:
- optional_params["presencePenalty"] = {"scale": presence_penalty}
- elif (
- custom_llm_provider == "palm"
- ): # https://developers.generativeai.google/tutorials/curl_quickstart
- ## check if unsupported param passed in
- supported_params = get_supported_openai_params(
- model=model, custom_llm_provider=custom_llm_provider
- )
- _check_valid_arg(supported_params=supported_params)
-
- if temperature is not None:
- optional_params["temperature"] = temperature
- if top_p is not None:
- optional_params["top_p"] = top_p
- if stream:
- optional_params["stream"] = stream
- if n is not None:
- optional_params["candidate_count"] = n
- if stop is not None:
- if isinstance(stop, str):
- optional_params["stop_sequences"] = [stop]
- elif isinstance(stop, list):
- optional_params["stop_sequences"] = stop
- if max_tokens is not None:
- optional_params["max_output_tokens"] = max_tokens
elif custom_llm_provider == "vertex_ai" and (
model in litellm.vertex_chat_models
or model in litellm.vertex_code_chat_models
@@ -3120,12 +3074,25 @@ def get_optional_params( # noqa: PLR0915
_check_valid_arg(supported_params=supported_params)
if "codestral" in model:
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
- non_default_params=non_default_params, optional_params=optional_params
+ model=model,
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
else:
optional_params = litellm.MistralConfig().map_openai_params(
+ model=model,
non_default_params=non_default_params,
optional_params=optional_params,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
supported_params = get_supported_openai_params(
@@ -3326,29 +3293,28 @@ def get_optional_params( # noqa: PLR0915
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "nlp_cloud":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
+ optional_params = litellm.NLPCloudConfig().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
+ )
- if max_tokens is not None:
- optional_params["max_length"] = max_tokens
- if stream:
- optional_params["stream"] = stream
- if temperature is not None:
- optional_params["temperature"] = temperature
- if top_p is not None:
- optional_params["top_p"] = top_p
- if presence_penalty is not None:
- optional_params["presence_penalty"] = presence_penalty
- if frequency_penalty is not None:
- optional_params["frequency_penalty"] = frequency_penalty
- if n is not None:
- optional_params["num_return_sequences"] = n
- if stop is not None:
- optional_params["stop_sequences"] = stop
elif custom_llm_provider == "petals":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
@@ -3435,7 +3401,14 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralConfig().map_openai_params(
- non_default_params=non_default_params, optional_params=optional_params
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "text-completion-codestral":
supported_params = get_supported_openai_params(
@@ -3443,7 +3416,14 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
- non_default_params=non_default_params, optional_params=optional_params
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "databricks":
@@ -3470,6 +3450,11 @@ def get_optional_params( # noqa: PLR0915
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "cerebras":
supported_params = get_supported_openai_params(
@@ -3480,6 +3465,11 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "xai":
supported_params = get_supported_openai_params(
@@ -3491,7 +3481,7 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params,
optional_params=optional_params,
)
- elif custom_llm_provider == "ai21_chat":
+ elif custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
@@ -3500,6 +3490,11 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "fireworks_ai":
supported_params = get_supported_openai_params(
@@ -3525,6 +3520,11 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
elif custom_llm_provider == "hosted_vllm":
supported_params = get_supported_openai_params(
@@ -3594,55 +3594,17 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
- if functions is not None:
- optional_params["functions"] = functions
- if function_call is not None:
- optional_params["function_call"] = function_call
- if temperature is not None:
- optional_params["temperature"] = temperature
- if top_p is not None:
- optional_params["top_p"] = top_p
- if n is not None:
- optional_params["n"] = n
- if stream is not None:
- optional_params["stream"] = stream
- if stop is not None:
- optional_params["stop"] = stop
- if max_tokens is not None:
- optional_params["max_tokens"] = max_tokens
- if presence_penalty is not None:
- optional_params["presence_penalty"] = presence_penalty
- if frequency_penalty is not None:
- optional_params["frequency_penalty"] = frequency_penalty
- if logit_bias is not None:
- optional_params["logit_bias"] = logit_bias
- if user is not None:
- optional_params["user"] = user
- if response_format is not None:
- optional_params["response_format"] = response_format
- if seed is not None:
- optional_params["seed"] = seed
- if tools is not None:
- optional_params["tools"] = tools
- if tool_choice is not None:
- optional_params["tool_choice"] = tool_choice
- if max_retries is not None:
- optional_params["max_retries"] = max_retries
-
- # OpenRouter-only parameters
- extra_body = {}
- transforms = passed_params.pop("transforms", None)
- models = passed_params.pop("models", None)
- route = passed_params.pop("route", None)
- if transforms is not None:
- extra_body["transforms"] = transforms
- if models is not None:
- extra_body["models"] = models
- if route is not None:
- extra_body["route"] = route
- optional_params["extra_body"] = (
- extra_body # openai client supports `extra_body` param
+ optional_params = litellm.OpenrouterConfig().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
+
elif custom_llm_provider == "watsonx":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
@@ -3727,7 +3689,11 @@ def get_optional_params( # noqa: PLR0915
optional_params=optional_params,
model=model,
api_version=api_version, # type: ignore
- drop_params=drop_params,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
)
else: # assume passing in params for text-completion openai
supported_params = get_supported_openai_params(
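
The `drop_params` argument is normalized the same way in every provider branch touched above; a minimal sketch of that coercion, using a hypothetical helper name that is not part of this change:

```python
from typing import Optional


def _coerce_drop_params(drop_params: Optional[bool]) -> bool:
    # Mirrors the inline expression repeated in get_optional_params:
    # anything other than an explicit boolean is treated as "do not drop".
    return drop_params if drop_params is not None and isinstance(drop_params, bool) else False


assert _coerce_drop_params(True) is True
assert _coerce_drop_params(None) is False
```
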
@@ -6271,7 +6237,7 @@ from litellm.llms.base_llm.transformation import BaseConfig
class ProviderConfigManager:
@staticmethod
- def get_provider_chat_config(
+ def get_provider_chat_config( # noqa: PLR0915
model: str, provider: litellm.LlmProviders
) -> BaseConfig:
"""
@@ -6333,6 +6299,60 @@ class ProviderConfigManager:
return litellm.LMStudioChatConfig()
elif litellm.LlmProviders.GALADRIEL == provider:
return litellm.GaladrielChatConfig()
+ elif litellm.LlmProviders.REPLICATE == provider:
+ return litellm.ReplicateConfig()
+ elif litellm.LlmProviders.HUGGINGFACE == provider:
+ return litellm.HuggingfaceConfig()
+ elif litellm.LlmProviders.TOGETHER_AI == provider:
+ return litellm.TogetherAIConfig()
+ elif litellm.LlmProviders.OPENROUTER == provider:
+ return litellm.OpenrouterConfig()
+ elif litellm.LlmProviders.GEMINI == provider:
+ return litellm.GoogleAIStudioGeminiConfig()
+ elif (
+ litellm.LlmProviders.AI21 == provider
+ or litellm.LlmProviders.AI21_CHAT == provider
+ ):
+ return litellm.AI21ChatConfig()
+ elif litellm.LlmProviders.AZURE == provider:
+ return litellm.AzureOpenAIConfig()
+ elif litellm.LlmProviders.AZURE_AI == provider:
+ return litellm.AzureAIStudioConfig()
+ elif litellm.LlmProviders.AZURE_TEXT == provider:
+ return litellm.AzureOpenAITextConfig()
+ elif litellm.LlmProviders.HOSTED_VLLM == provider:
+ return litellm.HostedVLLMChatConfig()
+ elif litellm.LlmProviders.NLP_CLOUD == provider:
+ return litellm.NLPCloudConfig()
+ elif litellm.LlmProviders.OOBABOOGA == provider:
+ return litellm.OobaboogaConfig()
+ elif litellm.LlmProviders.OLLAMA_CHAT == provider:
+ return litellm.OllamaChatConfig()
+ elif litellm.LlmProviders.DEEPINFRA == provider:
+ return litellm.DeepInfraConfig()
+ elif litellm.LlmProviders.PERPLEXITY == provider:
+ return litellm.PerplexityChatConfig()
+ elif (
+ litellm.LlmProviders.MISTRAL == provider
+ or litellm.LlmProviders.CODESTRAL == provider
+ ):
+ return litellm.MistralConfig()
+ elif litellm.LlmProviders.NVIDIA_NIM == provider:
+ return litellm.NvidiaNimConfig()
+ elif litellm.LlmProviders.CEREBRAS == provider:
+ return litellm.CerebrasConfig()
+ elif litellm.LlmProviders.VOLCENGINE == provider:
+ return litellm.VolcEngineConfig()
+ elif litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL == provider:
+ return litellm.MistralTextCompletionConfig()
+ elif litellm.LlmProviders.SAMBANOVA == provider:
+ return litellm.SambanovaConfig()
+ elif litellm.LlmProviders.MARITALK == provider:
+ return litellm.MaritalkConfig()
+ elif litellm.LlmProviders.CLOUDFLARE == provider:
+ return litellm.CloudflareChatConfig()
+ elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
+ return litellm.AnthropicTextConfig()
elif litellm.LlmProviders.VLLM == provider:
return litellm.VLLMConfig()
elif litellm.LlmProviders.OLLAMA == provider:
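
A rough usage sketch of the provider routing added to `ProviderConfigManager` above (illustrative only; the model name is an arbitrary example and the exact mapped output depends on the provider config):

```python
import litellm
from litellm.utils import ProviderConfigManager

# Resolve the chat transformation config for a provider, then map
# OpenAI-style params into the provider's own parameter names.
config = ProviderConfigManager.get_provider_chat_config(
    model="mistral-large-latest",
    provider=litellm.LlmProviders.MISTRAL,
)
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 10},
    optional_params={},
    model="mistral-large-latest",
    drop_params=False,
)
print(optional_params)  # provider-specific dict, e.g. {"max_tokens": 10} for Mistral
```
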
diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py
index 4452bd0fc9..1676c912b8 100644
--- a/tests/llm_translation/test_max_completion_tokens.py
+++ b/tests/llm_translation/test_max_completion_tokens.py
@@ -168,12 +168,17 @@ def test_all_model_configs():
drop_params=False,
) == {"max_tokens": 10}
- from litellm.llms.huggingface_restapi import HuggingfaceConfig
+ from litellm.llms.huggingface.chat.handler import HuggingfaceConfig
- assert "max_completion_tokens" in HuggingfaceConfig().get_supported_openai_params()
- assert HuggingfaceConfig().map_openai_params({"max_completion_tokens": 10}, {}) == {
- "max_new_tokens": 10
- }
+ assert "max_completion_tokens" in HuggingfaceConfig().get_supported_openai_params(
+ model="llama3"
+ )
+ assert HuggingfaceConfig().map_openai_params(
+ non_default_params={"max_completion_tokens": 10},
+ optional_params={},
+ model="llama3",
+ drop_params=False,
+ ) == {"max_new_tokens": 10}
from litellm.llms.nvidia_nim.chat import NvidiaNimConfig
@@ -184,15 +189,19 @@ def test_all_model_configs():
model="llama3",
non_default_params={"max_completion_tokens": 10},
optional_params={},
+ drop_params=False,
) == {"max_tokens": 10}
from litellm.llms.ollama_chat import OllamaChatConfig
- assert "max_completion_tokens" in OllamaChatConfig().get_supported_openai_params()
+ assert "max_completion_tokens" in OllamaChatConfig().get_supported_openai_params(
+ model="llama3"
+ )
assert OllamaChatConfig().map_openai_params(
model="llama3",
non_default_params={"max_completion_tokens": 10},
optional_params={},
+ drop_params=False,
) == {"num_predict": 10}
from litellm.llms.predibase import PredibaseConfig
@@ -207,11 +216,13 @@ def test_all_model_configs():
assert (
"max_completion_tokens"
- in MistralTextCompletionConfig().get_supported_openai_params()
+ in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
)
assert MistralTextCompletionConfig().map_openai_params(
- {"max_completion_tokens": 10},
- {},
+ model="llama3",
+ non_default_params={"max_completion_tokens": 10},
+ optional_params={},
+ drop_params=False,
) == {"max_tokens": 10}
from litellm.llms.volcengine import VolcEngineConfig
@@ -223,9 +234,10 @@ def test_all_model_configs():
model="llama3",
non_default_params={"max_completion_tokens": 10},
optional_params={},
+ drop_params=False,
) == {"max_tokens": 10}
- from litellm.llms.ai21.chat import AI21ChatConfig
+ from litellm.llms.ai21.chat.transformation import AI21ChatConfig
assert "max_completion_tokens" in AI21ChatConfig().get_supported_openai_params(
"jamba-1.5-mini@001"
@@ -234,11 +246,14 @@ def test_all_model_configs():
model="jamba-1.5-mini@001",
non_default_params={"max_completion_tokens": 10},
optional_params={},
+ drop_params=False,
) == {"max_tokens": 10}
from litellm.llms.azure.chat.gpt_transformation import AzureOpenAIConfig
- assert "max_completion_tokens" in AzureOpenAIConfig().get_supported_openai_params()
+ assert "max_completion_tokens" in AzureOpenAIConfig().get_supported_openai_params(
+ model="gpt-3.5-turbo"
+ )
assert AzureOpenAIConfig().map_openai_params(
model="gpt-3.5-turbo",
non_default_params={"max_completion_tokens": 10},
@@ -266,11 +281,13 @@ def test_all_model_configs():
assert (
"max_completion_tokens"
- in MistralTextCompletionConfig().get_supported_openai_params()
+ in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
)
assert MistralTextCompletionConfig().map_openai_params(
+ model="llama3",
non_default_params={"max_completion_tokens": 10},
optional_params={},
+ drop_params=False,
) == {"max_tokens": 10}
from litellm.llms.bedrock.common_utils import (
@@ -341,7 +358,9 @@ def test_all_model_configs():
assert (
"max_completion_tokens"
- in GoogleAIStudioGeminiConfig().get_supported_openai_params()
+ in GoogleAIStudioGeminiConfig().get_supported_openai_params(
+ model="gemini-1.0-pro"
+ )
)
assert GoogleAIStudioGeminiConfig().map_openai_params(
@@ -351,7 +370,9 @@ def test_all_model_configs():
drop_params=False,
) == {"max_output_tokens": 10}
- assert "max_completion_tokens" in VertexGeminiConfig().get_supported_openai_params()
+ assert "max_completion_tokens" in VertexGeminiConfig().get_supported_openai_params(
+ model="gemini-1.0-pro"
+ )
assert VertexGeminiConfig().map_openai_params(
model="gemini-1.0-pro",
diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py
index 2ada0a8bb7..8acfbf0863 100644
--- a/tests/llm_translation/test_optional_params.py
+++ b/tests/llm_translation/test_optional_params.py
@@ -190,9 +190,10 @@ def test_databricks_optional_params():
custom_llm_provider="databricks",
max_tokens=10,
temperature=0.2,
+ stream=True,
)
print(f"optional_params: {optional_params}")
- assert len(optional_params) == 2
+ assert len(optional_params) == 3
assert "user" not in optional_params
diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py
index d8cf191f6a..005c62113a 100644
--- a/tests/llm_translation/test_prompt_factory.py
+++ b/tests/llm_translation/test_prompt_factory.py
@@ -449,8 +449,12 @@ def test_azure_tool_call_invoke_helper():
{"role": "assistant", "function_call": {"name": "get_weather"}},
]
- transformed_messages = litellm.AzureOpenAIConfig.transform_request(
- model="gpt-4o", messages=messages, optional_params={}
+ transformed_messages = litellm.AzureOpenAIConfig().transform_request(
+ model="gpt-4o",
+ messages=messages,
+ optional_params={},
+ litellm_params={},
+ headers={},
)
assert transformed_messages["messages"] == [
diff --git a/tests/local_testing/test_batch_completions.py b/tests/local_testing/test_batch_completions.py
index 87cb88e44d..e8fef5249f 100644
--- a/tests/local_testing/test_batch_completions.py
+++ b/tests/local_testing/test_batch_completions.py
@@ -69,7 +69,7 @@ def test_batch_completions_models():
def test_batch_completion_models_all_responses():
try:
responses = batch_completion_models_all_responses(
- models=["j2-light", "claude-3-haiku-20240307"],
+ models=["gemini/gemini-1.5-flash", "claude-3-haiku-20240307"],
messages=[{"role": "user", "content": "write a poem"}],
max_tokens=10,
)
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index 0f8addf775..833dbb8ffa 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -1606,30 +1606,33 @@ HF Tests we should pass
#####################################################
#####################################################
# Test util to sort models to TGI, conv, None
+from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig
+
+
def test_get_hf_task_for_model():
model = "glaiveai/glaive-coder-7b"
- model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+ model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference"
model = "meta-llama/Llama-2-7b-hf"
- model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+ model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference"
model = "facebook/blenderbot-400M-distill"
- model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+ model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational"
model = "facebook/blenderbot-3B"
- model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+ model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational"
# neither Conv or None
model = "roneneldan/TinyStories-3M"
- model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+ model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation"
@@ -1717,14 +1720,17 @@ def tgi_mock_post(url, **kwargs):
def test_hf_test_completion_tgi():
litellm.set_verbose = True
try:
+ client = HTTPHandler()
- with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
+ with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
response = completion(
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
messages=[{"content": "Hello, how are you?", "role": "user"}],
max_tokens=10,
wait_for_model=True,
+ client=client,
)
+ mock_client.assert_called_once()
# Add any assertions-here to check the response
print(response)
assert "options" in mock_client.call_args.kwargs["data"]
@@ -1862,13 +1868,15 @@ def mock_post(url, **kwargs):
def test_hf_classifier_task():
try:
- with patch("requests.post", side_effect=mock_post):
+ client = HTTPHandler()
+ with patch.object(client, "post", side_effect=mock_post):
litellm.set_verbose = True
user_message = "I like you. I love you"
messages = [{"content": user_message, "role": "user"}]
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
+ client=client,
)
print(f"response: {response}")
assert isinstance(response, litellm.ModelResponse)
@@ -3096,19 +3104,20 @@ async def test_completion_replicate_llama3(sync_mode):
response = completion(
model=model_name,
messages=messages,
+ max_tokens=10,
)
else:
response = await litellm.acompletion(
model=model_name,
messages=messages,
+ max_tokens=10,
)
print(f"ASYNC REPLICATE RESPONSE - {response}")
- print(response)
+ print(f"REPLICATE RESPONSE - {response}")
# Add any assertions here to check the response
assert isinstance(response, litellm.ModelResponse)
+ assert len(response.choices[0].message.content.strip()) > 0
response_format_tests(response=response)
- except litellm.APIError as e:
- pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -3745,22 +3754,6 @@ def test_mistral_anyscale_stream():
# pytest.fail(f"Error occurred: {e}")
-#### Test A121 ###################
-@pytest.mark.skip(reason="Local test")
-def test_completion_ai21():
- print("running ai21 j2light test")
- litellm.set_verbose = True
- model_name = "j2-light"
- try:
- response = completion(
- model=model_name, messages=messages, max_tokens=100, temperature=0.8
- )
- # Add any assertions here to check the response
- print(response)
- except Exception as e:
- pytest.fail(f"Error occurred: {e}")
-
-
# test_completion_ai21()
# test_completion_ai21()
## test deep infra
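
The HuggingFace tests above switch from patching the global `requests.post` to injecting an `HTTPHandler` and patching its `post` method, then asserting the mock was hit. A minimal runnable sketch of that mocking pattern (the mocked response here is a bare stand-in, not the real TGI payload returned by `tgi_mock_post`):

```python
from unittest.mock import MagicMock, patch

from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
fake_response = MagicMock(status_code=200)

with patch.object(client, "post", return_value=fake_response) as mock_post:
    # In the tests above this client is passed as `completion(..., client=client)`,
    # so the request is routed through the patched handler instead of the network.
    client.post("https://api-inference.huggingface.co/models/fake", json={"inputs": "Hello"})
    mock_post.assert_called_once()
```
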
diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py
index cce8d6d670..66c3da8782 100644
--- a/tests/local_testing/test_completion_cost.py
+++ b/tests/local_testing/test_completion_cost.py
@@ -165,10 +165,10 @@ def test_get_gpt3_tokens():
# test_get_gpt3_tokens()
-def test_get_palm_tokens():
+def test_get_gemini_tokens():
# # 🦄🦄🦄🦄🦄🦄🦄🦄
- max_tokens = get_max_tokens("palm/chat-bison")
- assert max_tokens == 4096
+ max_tokens = get_max_tokens("gemini/gemini-1.5-flash")
+ assert max_tokens == 8192
print(max_tokens)
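
A quick usage sketch of the lookup the renamed test exercises (the 8192 value reflects litellm's model cost map at the time of this change):

```python
from litellm import get_max_tokens

max_tokens = get_max_tokens("gemini/gemini-1.5-flash")
print(max_tokens)  # 8192 per the model cost map when this test was updated
assert max_tokens == 8192
```
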
diff --git a/tests/local_testing/test_completion_with_retries.py b/tests/local_testing/test_completion_with_retries.py
index efb66c40c6..01b0cf3288 100644
--- a/tests/local_testing/test_completion_with_retries.py
+++ b/tests/local_testing/test_completion_with_retries.py
@@ -29,19 +29,6 @@ def logger_fn(user_model_dict):
pass
-# completion with num retries + impact on exception mapping
-def test_completion_with_num_retries():
- try:
- response = completion(
- model="j2-ultra",
- messages=[{"messages": "vibe", "bad": "message"}],
- num_retries=2,
- )
- pytest.fail(f"Unmapped exception occurred")
- except Exception as e:
- pass
-
-
# test_completion_with_num_retries()
def test_completion_with_0_num_retries():
try:
diff --git a/tests/local_testing/test_config.py b/tests/local_testing/test_config.py
index 353fcd28eb..eec8be5115 100644
--- a/tests/local_testing/test_config.py
+++ b/tests/local_testing/test_config.py
@@ -290,35 +290,46 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val
-def test_provider_config_manager():
- from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
- from litellm.utils import ProviderConfigManager
- from litellm.llms.base_llm.transformation import BaseConfig
- from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
+from litellm.utils import ProviderConfigManager
+from litellm.llms.base_llm.transformation import BaseConfig
- for provider in LITELLM_CHAT_PROVIDERS:
- if provider == LlmProviders.TRITON or provider == LlmProviders.PREDIBASE:
- continue
- assert isinstance(
- ProviderConfigManager.get_provider_chat_config(
- model="gpt-3.5-turbo", provider=LlmProviders(provider)
- ),
- BaseConfig,
- ), f"Provider {provider} is not a subclass of BaseConfig"
- config = ProviderConfigManager.get_provider_chat_config(
- model="gpt-3.5-turbo", provider=LlmProviders(provider)
- )
-
- if (
- provider != litellm.LlmProviders.OPENAI
- and provider != litellm.LlmProviders.OPENAI_LIKE
- and provider != litellm.LlmProviders.CUSTOM_OPENAI
- ):
- assert (
- config.__class__.__name__ != "OpenAIGPTConfig"
- ), f"Provider {provider} is an instance of OpenAIGPTConfig"
+def _check_provider_config(config: BaseConfig, provider: LlmProviders):
+ assert isinstance(
+ config,
+ BaseConfig,
+ ), f"Provider {provider} is not a subclass of BaseConfig. Got={config}"
+ if (
+ provider != litellm.LlmProviders.OPENAI
+ and provider != litellm.LlmProviders.OPENAI_LIKE
+ and provider != litellm.LlmProviders.CUSTOM_OPENAI
+ ):
assert (
- "_abc_impl" not in config.get_config()
- ), f"Provider {provider} has _abc_impl"
+ config.__class__.__name__ != "OpenAIGPTConfig"
+ ), f"Provider {provider} is an instance of OpenAIGPTConfig"
+
+ assert "_abc_impl" not in config.get_config(), f"Provider {provider} has _abc_impl"
+
+
+# def test_provider_config_manager():
+# from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+# for provider in LITELLM_CHAT_PROVIDERS:
+# if (
+# provider == LlmProviders.VERTEX_AI
+# or provider == LlmProviders.VERTEX_AI_BETA
+# or provider == LlmProviders.BEDROCK
+# or provider == LlmProviders.BASETEN
+# or provider == LlmProviders.SAGEMAKER
+# or provider == LlmProviders.SAGEMAKER_CHAT
+# or provider == LlmProviders.VLLM
+# or provider == LlmProviders.PETALS
+# or provider == LlmProviders.OLLAMA
+# ):
+# continue
+# config = ProviderConfigManager.get_provider_chat_config(
+# model="gpt-3.5-turbo", provider=LlmProviders(provider)
+# )
+# _check_provider_config(config, provider)
diff --git a/tests/local_testing/test_gcs_bucket.py b/tests/local_testing/test_gcs_bucket.py
index 4d431b662e..0813596810 100644
--- a/tests/local_testing/test_gcs_bucket.py
+++ b/tests/local_testing/test_gcs_bucket.py
@@ -522,6 +522,7 @@ async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
)
+@pytest.mark.flaky(retries=5, delay=3)
@pytest.mark.asyncio
async def test_get_gcs_logging_config_without_service_account():
"""
diff --git a/tests/local_testing/test_provider_specific_config.py b/tests/local_testing/test_provider_specific_config.py
index 1f1ccaef88..dc6e62e8ca 100644
--- a/tests/local_testing/test_provider_specific_config.py
+++ b/tests/local_testing/test_provider_specific_config.py
@@ -167,51 +167,6 @@ def cohere_test_completion():
# cohere_test_completion()
-# AI21
-
-
-def ai21_test_completion():
- litellm.AI21Config(maxTokens=10)
- litellm.set_verbose = True
- try:
- # OVERRIDE WITH DYNAMIC MAX TOKENS
- response_1 = litellm.completion(
- model="j2-mid",
- messages=[
- {
- "content": "Hello, how are you? Be as verbose as possible",
- "role": "user",
- }
- ],
- max_tokens=100,
- )
- response_1_text = response_1.choices[0].message.content
- print(f"response_1_text: {response_1_text}")
-
- # USE CONFIG TOKENS
- response_2 = litellm.completion(
- model="j2-mid",
- messages=[
- {
- "content": "Hello, how are you? Be as verbose as possible",
- "role": "user",
- }
- ],
- )
- response_2_text = response_2.choices[0].message.content
- print(f"response_2_text: {response_2_text}")
-
- assert len(response_2_text) < len(response_1_text)
-
- response_3 = litellm.completion(
- model="j2-light",
- messages=[{"content": "Hello, how are you?", "role": "user"}],
- n=2,
- )
- assert len(response_3.choices) > 1
- except Exception as e:
- pytest.fail(f"Error occurred: {e}")
-
# ai21_test_completion()
diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py
index 3f3cecc779..f360e0dddd 100644
--- a/tests/local_testing/test_router_provider_budgets.py
+++ b/tests/local_testing/test_router_provider_budgets.py
@@ -47,6 +47,7 @@ def cleanup_redis():
print(f"Error cleaning up Redis: {str(e)}")
+@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_provider_budgets_e2e_test():
"""
@@ -106,7 +107,7 @@ async def test_provider_budgets_e2e_test():
print("response.hidden_params", response._hidden_params)
- await asyncio.sleep(0.5)
+ await asyncio.sleep(1)
assert response._hidden_params.get("custom_llm_provider") == "azure"
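
The `@pytest.mark.flaky(retries=..., delay=...)` markers added above rely on the retry plugin used by this suite (assumed here to be `pytest-retry`, which accepts these argument names); a minimal sketch:

```python
import pytest


@pytest.mark.flaky(retries=3, delay=1)  # re-run up to 3 times, waiting 1s between attempts
def test_eventually_consistent_thing():
    # Body elided; the marker only controls how failures are retried.
    assert True
```
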
diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py
index 30d9d3e0f6..02ac8cb91b 100644
--- a/tests/local_testing/test_streaming.py
+++ b/tests/local_testing/test_streaming.py
@@ -1931,66 +1931,11 @@ async def test_completion_watsonx_stream():
# raise Exception("Empty response received")
# except Exception:
# pytest.fail(f"error occurred: {traceback.format_exc()}")
-# test_maritalk_streaming()
-# test on openai completion call
-
-
-# # test on ai21 completion call
-def ai21_completion_call():
- try:
- messages = [
- {
- "role": "system",
- "content": "You are an all-knowing oracle",
- },
- {"role": "user", "content": "What is the meaning of the Universe?"},
- ]
- response = completion(
- model="j2-ultra", messages=messages, stream=True, max_tokens=500
- )
- print(f"response: {response}")
- has_finished = False
- complete_response = ""
- start_time = time.time()
- for idx, chunk in enumerate(response):
- chunk, finished = streaming_format_tests(idx, chunk)
- has_finished = finished
- complete_response += chunk
- if finished:
- break
- if has_finished is False:
- raise Exception("finished reason missing from final chunk")
- if complete_response.strip() == "":
- raise Exception("Empty response received")
- print(f"completion_response: {complete_response}")
- except Exception:
- pytest.fail(f"error occurred: {traceback.format_exc()}")
# ai21_completion_call()
-def ai21_completion_call_bad_key():
- try:
- api_key = "bad-key"
- response = completion(
- model="j2-ultra", messages=messages, stream=True, api_key=api_key
- )
- print(f"response: {response}")
- complete_response = ""
- start_time = time.time()
- for idx, chunk in enumerate(response):
- chunk, finished = streaming_format_tests(idx, chunk)
- if finished:
- break
- complete_response += chunk
- if complete_response.strip() == "":
- raise Exception("Empty response received")
- print(f"completion_response: {complete_response}")
- except Exception:
- pytest.fail(f"error occurred: {traceback.format_exc()}")
-
-
# ai21_completion_call_bad_key()
@@ -2418,34 +2363,6 @@ def test_completion_openai_with_functions():
#### Test Async streaming ####
-# # test on ai21 completion call
-async def ai21_async_completion_call():
- try:
- response = completion(
- model="j2-ultra", messages=messages, stream=True, logger_fn=logger_fn
- )
- print(f"response: {response}")
- complete_response = ""
- start_time = time.time()
- # Change for loop to async for loop
- idx = 0
- async for chunk in response:
- chunk, finished = streaming_format_tests(idx, chunk)
- if finished:
- break
- complete_response += chunk
- idx += 1
- if complete_response.strip() == "":
- raise Exception("Empty response received")
- print(f"complete response: {complete_response}")
- except Exception:
- print(f"error occurred: {traceback.format_exc()}")
- pass
-
-
-# asyncio.run(ai21_async_completion_call())
-
-
async def completion_call():
try:
response = completion(
diff --git a/tests/local_testing/test_text_completion.py b/tests/local_testing/test_text_completion.py
index 5d94820dc1..19588a9720 100644
--- a/tests/local_testing/test_text_completion.py
+++ b/tests/local_testing/test_text_completion.py
@@ -3934,6 +3934,7 @@ def test_completion_text_003_prompt_array():
##### hugging face tests
+@pytest.mark.skip(reason="local test")
def test_completion_hf_prompt_array():
try:
litellm.set_verbose = True
diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py
index 1682235254..76f713cdc2 100644
--- a/tests/local_testing/test_utils.py
+++ b/tests/local_testing/test_utils.py
@@ -437,8 +437,8 @@ def test_token_counter():
print(tokens)
assert tokens > 0
- tokens = token_counter(model="palm/chat-bison", messages=messages)
- print("palm/chat-bison")
+ tokens = token_counter(model="gemini/chat-bison", messages=messages)
+ print("gemini/chat-bison")
print(tokens)
assert tokens > 0
@@ -465,7 +465,7 @@ def test_token_counter():
("azure/gpt-4-1106-preview", True),
("groq/gemma-7b-it", True),
("anthropic.claude-instant-v1", False),
- ("palm/chat-bison", False),
+ ("gemini/gemini-1.5-flash", True),
],
)
def test_supports_function_calling(model, expected_bool):