Litellm merge pr (#7161)

* build: merge branch

* test: fix openai naming

* fix(main.py): fix openai renaming

* style: ignore function length for config factory

* fix(sagemaker/): fix routing logic

* fix: fix imports

* fix: fix override
Krish Dholakia authored on 2024-12-10 22:49:26 -08:00; committed by GitHub
parent d5aae81c6d
commit 350cfc36f7
88 changed files with 3617 additions and 4421 deletions


@ -1,43 +0,0 @@
# PaLM API - Google
:::warning
Warning: [The PaLM API is decommissioned by Google](https://ai.google.dev/palm_docs/deprecation). The PaLM API was scheduled to be decommissioned in October 2024. Please upgrade to the Gemini API or Vertex AI API.
:::
## Pre-requisites
* `pip install -q google-generativeai`
## Sample Usage
```python
from litellm import completion
import os
os.environ['PALM_API_KEY'] = ""
response = completion(
model="palm/chat-bison",
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
)
```
## Sample Usage - Streaming
```python
from litellm import completion
import os
os.environ['PALM_API_KEY'] = ""
response = completion(
model="palm/chat-bison",
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
stream=True
)
for chunk in response:
print(chunk)
```
## Chat Models
| Model Name | Function Call | Required OS Variables |
|------------------|--------------------------------------|-------------------------|
| chat-bison | `completion('palm/chat-bison', messages)` | `os.environ['PALM_API_KEY']` |
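
Since the warning above points users at the Gemini API, here is a minimal, hedged migration sketch; it assumes LiteLLM's `gemini/` route and `gemini/gemini-pro` as the replacement model, so adjust to whichever Gemini model you actually use.
```python
# Hedged sketch of the suggested upgrade path, not part of the original docs.
# Assumes GEMINI_API_KEY and the gemini/gemini-pro model name.
from litellm import completion
import os

os.environ["GEMINI_API_KEY"] = ""

response = completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
)
print(response)
```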


@@ -190,11 +190,9 @@ const sidebars = {
         "providers/aleph_alpha",
         "providers/baseten",
         "providers/openrouter",
-        "providers/palm",
         "providers/sambanova",
         "providers/custom_llm_server",
         "providers/petals",
       ],
     },
     {


@ -601,6 +601,7 @@ openai_compatible_providers: List = [
"cerebras", "cerebras",
"sambanova", "sambanova",
"ai21_chat", "ai21_chat",
"ai21",
"volcengine", "volcengine",
"codestral", "codestral",
"deepseek", "deepseek",
@ -853,7 +854,6 @@ class LlmProviders(str, Enum):
OPENROUTER = "openrouter" OPENROUTER = "openrouter"
VERTEX_AI = "vertex_ai" VERTEX_AI = "vertex_ai"
VERTEX_AI_BETA = "vertex_ai_beta" VERTEX_AI_BETA = "vertex_ai_beta"
PALM = "palm"
GEMINI = "gemini" GEMINI = "gemini"
AI21 = "ai21" AI21 = "ai21"
BASETEN = "baseten" BASETEN = "baseten"
@ -871,7 +871,6 @@ class LlmProviders(str, Enum):
OLLAMA_CHAT = "ollama_chat" OLLAMA_CHAT = "ollama_chat"
DEEPINFRA = "deepinfra" DEEPINFRA = "deepinfra"
PERPLEXITY = "perplexity" PERPLEXITY = "perplexity"
ANYSCALE = "anyscale"
MISTRAL = "mistral" MISTRAL = "mistral"
GROQ = "groq" GROQ = "groq"
NVIDIA_NIM = "nvidia_nim" NVIDIA_NIM = "nvidia_nim"
@ -1057,10 +1056,15 @@ from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM from .llms.custom_llm import CustomLLM
from .llms.openai_like.chat.handler import OpenAILikeChatConfig from .llms.openai_like.chat.handler import OpenAILikeChatConfig
from .llms.galadriel.chat.transformation import GaladrielChatConfig from .llms.galadriel.chat.transformation import GaladrielChatConfig
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.empower.chat.transformation import EmpowerChatConfig
from .llms.github.chat.transformation import GithubChatConfig from .llms.github.chat.transformation import GithubChatConfig
from .llms.anthropic.chat.handler import AnthropicConfig from .llms.empower.chat.transformation import EmpowerChatConfig
from .llms.huggingface.chat.transformation import (
HuggingfaceChatConfig as HuggingfaceConfig,
)
from .llms.oobabooga.chat.transformation import OobaboogaConfig
from .llms.maritalk import MaritalkConfig
from .llms.openrouter.chat.transformation import OpenrouterConfig
from .llms.anthropic.chat.transformation import AnthropicConfig
from .llms.anthropic.experimental_pass_through.transformation import ( from .llms.anthropic.experimental_pass_through.transformation import (
AnthropicExperimentalPassThroughConfig, AnthropicExperimentalPassThroughConfig,
) )
@ -1069,24 +1073,26 @@ from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.databricks.chat.transformation import DatabricksConfig from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase import PredibaseConfig from .llms.predibase import PredibaseConfig
from .llms.replicate import ReplicateConfig from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.cloudflare.chat.transformation import CloudflareChatConfig from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.ai21.completion import AI21Config
from .llms.ai21.chat import AI21ChatConfig
from .llms.together_ai.chat import TogetherAIConfig from .llms.together_ai.chat import TogetherAIConfig
from .llms.palm import PalmConfig from .llms.cloudflare.chat.transformation import CloudflareChatConfig
from .llms.gemini import GeminiConfig from .llms.deprecated_providers.palm import (
from .llms.nlp_cloud import NLPCloudConfig PalmConfig,
) # here to prevent breaking changes
from .llms.nlp_cloud.chat.handler import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig from .llms.petals import PetalsConfig
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig, VertexGeminiConfig,
GoogleAIStudioGeminiConfig, GoogleAIStudioGeminiConfig,
VertexAIConfig, VertexAIConfig,
GoogleAIStudioGeminiConfig as GeminiConfig,
) )
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import ( from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import (
VertexAITextEmbeddingConfig, VertexAITextEmbeddingConfig,
) )
@ -1107,7 +1113,6 @@ from .llms.ollama.completion.transformation import OllamaConfig
from .llms.sagemaker.completion.transformation import SagemakerConfig from .llms.sagemaker.completion.transformation import SagemakerConfig
from .llms.sagemaker.chat.transformation import SagemakerChatConfig from .llms.sagemaker.chat.transformation import SagemakerChatConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock.chat.invoke_handler import ( from .llms.bedrock.chat.invoke_handler import (
AmazonCohereChatConfig, AmazonCohereChatConfig,
AmazonConverseConfig, AmazonConverseConfig,
@ -1134,11 +1139,8 @@ from .llms.bedrock.embed.amazon_titan_v2_transformation import (
) )
from .llms.cohere.chat.transformation import CohereChatConfig from .llms.cohere.chat.transformation import CohereChatConfig
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
from .llms.openai.openai import ( from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
OpenAIConfig, from .llms.deepinfra.chat.transformation import DeepInfraConfig
MistralEmbeddingConfig,
DeepInfraConfig,
)
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig from .llms.groq.chat.transformation import GroqChatConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
@ -1167,7 +1169,7 @@ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()
from .llms.cerebras.chat import CerebrasConfig from .llms.cerebras.chat import CerebrasConfig
from .llms.sambanova.chat import SambanovaConfig from .llms.sambanova.chat import SambanovaConfig
from .llms.ai21.chat import AI21ChatConfig from .llms.ai21.chat.transformation import AI21ChatConfig
from .llms.fireworks_ai.chat.transformation import FireworksAIConfig from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
from .llms.fireworks_ai.embed.fireworks_ai_transformation import ( from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
FireworksAIEmbeddingConfig, FireworksAIEmbeddingConfig,
@ -1183,6 +1185,7 @@ from .llms.azure.azure import (
) )
from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
from .llms.azure.completion.transformation import AzureOpenAITextConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.vllm.completion.transformation import VLLMConfig from .llms.vllm.completion.transformation import VLLMConfig
from .llms.deepseek.chat.transformation import DeepSeekChatConfig from .llms.deepseek.chat.transformation import DeepSeekChatConfig


@@ -3,54 +3,51 @@ DEFAULT_BATCH_SIZE = 512
 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
 DEFAULT_MAX_RETRIES = 2
 LITELLM_CHAT_PROVIDERS = [
-    # "openai",
-    # "openai_like",
-    # "xai",
-    # "custom_openai",
-    # "text-completion-openai",
-    # "cohere",
-    # "cohere_chat",
-    # "clarifai",
-    # "anthropic",
-    # "anthropic_text",
-    # "replicate",
-    # "huggingface",
-    # "together_ai",
-    # "openrouter",
-    # "vertex_ai",
-    # "vertex_ai_beta",
-    # "palm",
-    # "gemini",
-    # "ai21",
-    # "baseten",
-    # "azure",
-    # "azure_text",
-    # "azure_ai",
-    # "sagemaker",
-    # "sagemaker_chat",
-    # "bedrock",
+    "openai",
+    "openai_like",
+    "xai",
+    "custom_openai",
+    "text-completion-openai",
+    "cohere",
+    "cohere_chat",
+    "clarifai",
+    "anthropic",
+    "anthropic_text",
+    "replicate",
+    "huggingface",
+    "together_ai",
+    "openrouter",
+    "vertex_ai",
+    "vertex_ai_beta",
+    "gemini",
+    "ai21",
+    "baseten",
+    "azure",
+    "azure_text",
+    "azure_ai",
+    "sagemaker",
+    "sagemaker_chat",
+    "bedrock",
     "vllm",
-    # "nlp_cloud",
-    # "petals",
-    # "oobabooga",
+    "nlp_cloud",
+    "petals",
+    "oobabooga",
     "ollama",
-    # "ollama_chat",
-    # "deepinfra",
-    # "perplexity",
-    # "anyscale",
-    # "mistral",
-    # "groq",
-    # "nvidia_nim",
-    # "cerebras",
-    # "ai21_chat",
-    # "volcengine",
-    # "codestral",
-    # "text-completion-codestral",
-    # "deepseek",
-    # "sambanova",
-    # "maritalk",
-    # "voyage",
-    # "cloudflare",
+    "ollama_chat",
+    "deepinfra",
+    "perplexity",
+    "mistral",
+    "groq",
+    "nvidia_nim",
+    "cerebras",
+    "ai21_chat",
+    "volcengine",
+    "codestral",
+    "text-completion-codestral",
+    "deepseek",
+    "sambanova",
+    "maritalk",
+    "cloudflare",
     "fireworks_ai",
     "friendliai",
     "watsonx",


@@ -285,9 +285,7 @@ def get_llm_provider(  # noqa: PLR0915
         ):
             custom_llm_provider = "vertex_ai"
         ## ai21
-        elif model in litellm.ai21_models:
-            custom_llm_provider = "ai21"
-        elif model in litellm.ai21_chat_models:
+        elif model in litellm.ai21_chat_models or model in litellm.ai21_models:
             custom_llm_provider = "ai21_chat"
             api_base = (
                 api_base
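
A short, hedged sketch of the routing behaviour after this change; it assumes the example model name is registered in `litellm.ai21_chat_models`.
```python
# Illustrative only: a bare AI21 model name (no "ai21/" prefix) now resolves
# to the "ai21_chat" provider path. "jamba-1.5-mini" is an example name.
import litellm

model, provider, _, _ = litellm.get_llm_provider(model="jamba-1.5-mini")
print(provider)  # expected: "ai21_chat"
```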


@ -31,7 +31,7 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "ollama": elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params(model=model) return litellm.OllamaConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama_chat": elif custom_llm_provider == "ollama_chat":
return litellm.OllamaChatConfig().get_supported_openai_params() return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "anthropic": elif custom_llm_provider == "anthropic":
return litellm.AnthropicConfig().get_supported_openai_params(model=model) return litellm.AnthropicConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "fireworks_ai": elif custom_llm_provider == "fireworks_ai":
@ -50,7 +50,7 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.CerebrasConfig().get_supported_openai_params(model=model) return litellm.CerebrasConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "xai": elif custom_llm_provider == "xai":
return litellm.XAIChatConfig().get_supported_openai_params(model=model) return litellm.XAIChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ai21_chat": elif custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21":
return litellm.AI21ChatConfig().get_supported_openai_params(model=model) return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "volcengine": elif custom_llm_provider == "volcengine":
return litellm.VolcEngineConfig().get_supported_openai_params(model=model) return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
@ -97,79 +97,50 @@ def get_supported_openai_params( # noqa: PLR0915
model=model model=model
) )
else: else:
return litellm.AzureOpenAIConfig().get_supported_openai_params() return litellm.AzureOpenAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "openrouter": elif custom_llm_provider == "openrouter":
return [ return litellm.OpenrouterConfig().get_supported_openai_params(model=model)
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
"repetition_penalty",
"seed",
"max_tokens",
"logit_bias",
"logprobs",
"top_logprobs",
"response_format",
"stop",
"tools",
"tool_choice",
]
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral": elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
# mistal and codestral api have the exact same params # mistal and codestral api have the exact same params
if request_type == "chat_completion": if request_type == "chat_completion":
return litellm.MistralConfig().get_supported_openai_params() return litellm.MistralConfig().get_supported_openai_params(model=model)
elif request_type == "embeddings": elif request_type == "embeddings":
return litellm.MistralEmbeddingConfig().get_supported_openai_params() return litellm.MistralEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "text-completion-codestral": elif custom_llm_provider == "text-completion-codestral":
return litellm.MistralTextCompletionConfig().get_supported_openai_params() return litellm.MistralTextCompletionConfig().get_supported_openai_params(
model=model
)
elif custom_llm_provider == "sambanova":
return litellm.SambanovaConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "replicate": elif custom_llm_provider == "replicate":
return [ return litellm.ReplicateConfig().get_supported_openai_params(model=model)
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"seed",
"tools",
"tool_choice",
"functions",
"function_call",
]
elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface":
return litellm.HuggingfaceConfig().get_supported_openai_params() return litellm.HuggingfaceConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "jina_ai": elif custom_llm_provider == "jina_ai":
if request_type == "embeddings": if request_type == "embeddings":
return litellm.JinaAIEmbeddingConfig().get_supported_openai_params() return litellm.JinaAIEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
return litellm.TogetherAIConfig().get_supported_openai_params(model=model) return litellm.TogetherAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ai21":
return [
"stream",
"n",
"temperature",
"max_tokens",
"top_p",
"stop",
"frequency_penalty",
"presence_penalty",
]
elif custom_llm_provider == "databricks": elif custom_llm_provider == "databricks":
if request_type == "chat_completion": if request_type == "chat_completion":
return litellm.DatabricksConfig().get_supported_openai_params(model=model) return litellm.DatabricksConfig().get_supported_openai_params(model=model)
elif request_type == "embeddings": elif request_type == "embeddings":
return litellm.DatabricksEmbeddingConfig().get_supported_openai_params() return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params() return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params(
model=model
)
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
if request_type == "chat_completion": if request_type == "chat_completion":
if model.startswith("meta/"): if model.startswith("meta/"):
return litellm.VertexAILlama3Config().get_supported_openai_params() return litellm.VertexAILlama3Config().get_supported_openai_params()
if model.startswith("mistral"): if model.startswith("mistral"):
return litellm.MistralConfig().get_supported_openai_params() return litellm.MistralConfig().get_supported_openai_params(model=model)
if model.startswith("codestral"): if model.startswith("codestral"):
return ( return (
litellm.MistralTextCompletionConfig().get_supported_openai_params() litellm.MistralTextCompletionConfig().get_supported_openai_params(
model=model
)
) )
if model.startswith("claude"): if model.startswith("claude"):
return litellm.VertexAIAnthropicConfig().get_supported_openai_params( return litellm.VertexAIAnthropicConfig().get_supported_openai_params(
@ -180,7 +151,7 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params() return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "vertex_ai_beta": elif custom_llm_provider == "vertex_ai_beta":
if request_type == "chat_completion": if request_type == "chat_completion":
return litellm.VertexGeminiConfig().get_supported_openai_params() return litellm.VertexGeminiConfig().get_supported_openai_params(model=model)
elif request_type == "embeddings": elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params() return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker": elif custom_llm_provider == "sagemaker":
@ -199,20 +170,11 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "cloudflare": elif custom_llm_provider == "cloudflare":
return litellm.CloudflareChatConfig().get_supported_openai_params(model=model) return litellm.CloudflareChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "nlp_cloud": elif custom_llm_provider == "nlp_cloud":
return [ return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
"max_tokens",
"stream",
"temperature",
"top_p",
"presence_penalty",
"frequency_penalty",
"n",
"stop",
]
elif custom_llm_provider == "petals": elif custom_llm_provider == "petals":
return ["max_tokens", "temperature", "top_p", "stream"] return ["max_tokens", "temperature", "top_p", "stream"]
elif custom_llm_provider == "deepinfra": elif custom_llm_provider == "deepinfra":
return litellm.DeepInfraConfig().get_supported_openai_params() return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "perplexity": elif custom_llm_provider == "perplexity":
return [ return [
"temperature", "temperature",


@@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
 import types
 from typing import Optional, Union
+from ...openai_like.chat.transformation import OpenAILikeChatConfig
-class AI21ChatConfig:
+class AI21ChatConfig(OpenAILikeChatConfig):
     """
     Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters
@@ -19,8 +21,6 @@ class AI21ChatConfig:
     response_format: Optional[dict] = None
     documents: Optional[list] = None
     max_tokens: Optional[int] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
     stop: Optional[Union[str, list]] = None
     n: Optional[int] = None
     stream: Optional[bool] = None
@@ -49,21 +49,7 @@ class AI21ChatConfig:
     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()
     def get_supported_openai_params(self, model: str) -> list:
         """
@@ -77,22 +63,9 @@ class AI21ChatConfig:
             "max_tokens",
             "max_completion_tokens",
             "temperature",
-            "top_p",
             "stop",
             "n",
             "stream",
             "seed",
             "tool_choice",
-            "user",
         ]
-    def map_openai_params(
-        self, model: str, non_default_params: dict, optional_params: dict
-    ) -> dict:
-        supported_openai_params = self.get_supported_openai_params(model=model)
-        for param, value in non_default_params.items():
-            if param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            elif param in supported_openai_params:
-                optional_params[param] = value
-        return optional_params
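
A brief, hedged illustration of the slimmed-down config above; the model name is a placeholder.
```python
# AI21ChatConfig now inherits from OpenAILikeChatConfig, so supported params
# come from the shared config path shown in this hunk.
import litellm

params = litellm.AI21ChatConfig().get_supported_openai_params(model="jamba-1.5-large")
print("max_completion_tokens" in params)  # True, per the list above
print("top_p" in params)                  # False, removed in this hunk
```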


@ -1,221 +0,0 @@
import json
import os
import time # type: ignore
import traceback
import types
from enum import Enum
from typing import Callable, Optional
import httpx
import requests # type: ignore
import litellm
from litellm.utils import Choices, Message, ModelResponse
class AI21Error(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.ai21.com/studio/v1/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AI21Config:
"""
Reference: https://docs.ai21.com/reference/j2-complete-ref
The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters:
- `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful.
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
- `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated.
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
- `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position.
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
- `presencePenalty` (object): Placeholder for presence penalty object.
- `countPenalty` (object): Placeholder for count penalty object.
"""
numResults: Optional[int] = None
maxTokens: Optional[int] = None
minTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
stopSequences: Optional[list] = None
topKReturn: Optional[int] = None
frequencePenalty: Optional[dict] = None
presencePenalty: Optional[dict] = None
countPenalty: Optional[dict] = None
def __init__(
self,
numResults: Optional[int] = None,
maxTokens: Optional[int] = None,
minTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
topKReturn: Optional[int] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
if api_key is None:
raise ValueError(
"Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
)
headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": "Bearer " + api_key,
}
return headers
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
model = model
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
## Load Config
config = litellm.AI21Config.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"prompt": prompt,
# "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
api_base + model + "/complete", headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AI21Error(status_code=response.status_code, message=response.text)
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
completion_response = response.json()
try:
choices_list = []
for idx, item in enumerate(completion_response["completions"]):
if len(item["data"]["text"]) > 0:
message_obj = Message(content=item["data"]["text"])
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finishReason"]["reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except Exception:
raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
model_response.created = int(time.time())
model_response.model = model
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass


@@ -52,20 +52,6 @@ from ..common_utils import AnthropicError, process_anthropic_headers
 from .transformation import AnthropicConfig
-# makes headers for API call
-def validate_environment(
-    api_key,
-    user_headers,
-    model,
-    messages: List[AllMessageValues],
-    is_vertex_request: bool,
-    tools: Optional[List[AllAnthropicToolsValues]],
-    anthropic_version: Optional[str] = None,
-):
-    pass
 async def make_call(
     client: Optional[AsyncHTTPHandler],
     api_base: str,
@@ -239,7 +225,7 @@ class AnthropicChatCompletion(BaseLLM):
         data: dict,
         optional_params: dict,
         json_mode: bool,
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
@@ -283,6 +269,7 @@ class AnthropicChatCompletion(BaseLLM):
                 request_data=data,
                 messages=messages,
                 optional_params=optional_params,
+                litellm_params=litellm_params,
                 encoding=encoding,
                 json_mode=json_mode,
             )
@@ -460,6 +447,7 @@ class AnthropicChatCompletion(BaseLLM):
                 request_data=data,
                 messages=messages,
                 optional_params=optional_params,
+                litellm_params=litellm_params,
                 encoding=encoding,
                 json_mode=json_mode,
             )


@@ -567,6 +567,7 @@ class AnthropicConfig(BaseConfig):
         request_data: Dict,
         messages: List[AllMessageValues],
         optional_params: Dict,
+        litellm_params: dict,
         encoding: Any,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
@@ -715,11 +716,6 @@ class AnthropicConfig(BaseConfig):
             return litellm.Message(content=json_mode_content_str)
         return None
-    def _transform_messages(
-        self, messages: List[AllMessageValues]
-    ) -> List[AllMessageValues]:
-        return messages
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
     ) -> BaseLLMException:


@@ -180,6 +180,7 @@ class AnthropicTextConfig(BaseConfig):
         request_data: dict,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         encoding: str,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,


@ -35,38 +35,11 @@ from ...types.llms.openai import (
RetrieveBatchRequest, RetrieveBatchRequest,
) )
from ..base import BaseLLM from ..base import BaseLLM
from .common_utils import process_azure_headers from .common_utils import AzureOpenAIError, process_azure_headers
azure_ad_cache = DualCache() azure_ad_cache = DualCache()
class AzureOpenAIError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.headers = headers
if request:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
if response:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AzureOpenAIAssistantsAPIConfig: class AzureOpenAIAssistantsAPIConfig:
""" """
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
@ -412,8 +385,12 @@ class AzureChatCompletion(BaseLLM):
data = {"model": None, "messages": messages, **optional_params} data = {"model": None, "messages": messages, **optional_params}
else: else:
data = litellm.AzureOpenAIConfig.transform_request( data = litellm.AzureOpenAIConfig().transform_request(
model=model, messages=messages, optional_params=optional_params model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers or {},
) )
if acompletion is True: if acompletion is True:
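
For context, a hedged sketch of the instance-level call introduced in this hunk; all inputs are placeholders.
```python
# transform_request is now an instance method that also receives
# litellm_params and headers. Values below are placeholders.
import litellm

data = litellm.AzureOpenAIConfig().transform_request(
    model="my-azure-gpt-4o-deployment",
    messages=[{"role": "user", "content": "hi"}],
    optional_params={"temperature": 0.1},
    litellm_params={},
    headers={},
)
print(data)  # request dict with the converted messages plus optional params
```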


@ -1,7 +1,10 @@
import types import types
from typing import List, Optional, Type, Union from typing import TYPE_CHECKING, Any, List, Optional, Type, Union
from httpx._models import Headers, Response
import litellm import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from ....exceptions import UnsupportedParamsError from ....exceptions import UnsupportedParamsError
from ....types.llms.openai import ( from ....types.llms.openai import (
@ -11,10 +14,19 @@ from ....types.llms.openai import (
ChatCompletionToolParam, ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk, ChatCompletionToolParamFunctionChunk,
) )
from ...base_llm.transformation import BaseConfig
from ...prompt_templates.factory import convert_to_azure_openai_messages from ...prompt_templates.factory import convert_to_azure_openai_messages
from ..common_utils import AzureOpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
class AzureOpenAIConfig: class AzureOpenAIConfig(BaseConfig):
""" """
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
@ -61,23 +73,9 @@ class AzureOpenAIConfig:
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return { return super().get_config()
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self): def get_supported_openai_params(self, model: str) -> List[str]:
return [ return [
"temperature", "temperature",
"n", "n",
@ -110,10 +108,10 @@ class AzureOpenAIConfig:
non_default_params: dict, non_default_params: dict,
optional_params: dict, optional_params: dict,
model: str, model: str,
api_version: str, # Y-M-D-{optional} drop_params: bool,
drop_params, api_version: str = "",
) -> dict: ) -> dict:
supported_openai_params = self.get_supported_openai_params() supported_openai_params = self.get_supported_openai_params(model)
api_version_times = api_version.split("-") api_version_times = api_version.split("-")
api_version_year = api_version_times[0] api_version_year = api_version_times[0]
@ -204,9 +202,13 @@ class AzureOpenAIConfig:
return optional_params return optional_params
@classmethod
def transform_request( def transform_request(
cls, model: str, messages: List[AllMessageValues], optional_params: dict self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict: ) -> dict:
messages = convert_to_azure_openai_messages(messages) messages = convert_to_azure_openai_messages(messages)
return { return {
@ -215,6 +217,24 @@ class AzureOpenAIConfig:
**optional_params, **optional_params,
} }
def transform_response(
self,
model: str,
raw_response: Response,
model_response: litellm.ModelResponse,
logging_obj: LoggingClass,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
raise NotImplementedError(
"Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
)
def get_mapped_special_auth_params(self) -> dict: def get_mapped_special_auth_params(self) -> dict:
return {"token": "azure_ad_token"} return {"token": "azure_ad_token"}
@ -246,3 +266,22 @@ class AzureOpenAIConfig:
"westus3", "westus3",
"westus4", "westus4",
] ]
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return AzureOpenAIError(
message=error_message, status_code=status_code, headers=headers
)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
raise NotImplementedError(
"Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
)


@@ -1,7 +1,27 @@
-from typing import Union
+from typing import Optional, Union
 import httpx
+from litellm.llms.base_llm.transformation import BaseLLMException
+class AzureOpenAIError(BaseLLMException):
+    def __init__(
+        self,
+        status_code,
+        message,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )
 def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
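
A small, hedged sketch of what the relocated error class gives callers; the `litellm.llms.azure.common_utils` import path is inferred from this hunk.
```python
# AzureOpenAIError now builds on BaseLLMException, which fills in default
# httpx request/response objects when none are supplied.
from litellm.llms.azure.common_utils import AzureOpenAIError

err = AzureOpenAIError(status_code=429, message="rate limited")
print(err.status_code, err.response.status_code)  # 429 429
```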


@ -19,104 +19,16 @@ from litellm.utils import (
convert_to_model_response_object, convert_to_model_response_object,
) )
from .base import BaseLLM from ...base import BaseLLM
from .openai.completion.handler import OpenAITextCompletion from ...openai.completion.handler import OpenAITextCompletion
from .openai.completion.transformation import OpenAITextCompletionConfig from ...openai.completion.transformation import OpenAITextCompletionConfig
from .prompt_templates.factory import custom_prompt, prompt_factory from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import AzureOpenAIError
openai_text_completion_config = OpenAITextCompletionConfig() openai_text_completion_config = OpenAITextCompletionConfig()
class AzureOpenAIError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.headers = headers
if request:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
if response:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AzureOpenAIConfig(OpenAIConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
super().__init__(
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stop=stop,
temperature=temperature,
top_p=top_p,
)
def select_azure_base_url_or_endpoint(azure_client_params: dict): def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = {
# "api_version": api_version,
# "azure_endpoint": api_base,
# "azure_deployment": model,
# "http_client": litellm.client_session,
# "max_retries": max_retries,
# "timeout": timeout,
# }
azure_endpoint = azure_client_params.get("azure_endpoint", None) azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None: if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192 # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192


@ -0,0 +1,53 @@
from typing import Optional, Union
from ...openai.completion.transformation import OpenAITextCompletionConfig
class AzureOpenAITextConfig(OpenAITextCompletionConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
def __init__(
self,
frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
super().__init__(
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stop=stop,
temperature=temperature,
top_p=top_p,
)


@@ -1 +0,0 @@
-from .handler import AzureAIChatCompletion


@ -1,59 +1,3 @@
from typing import Any, Callable, List, Optional, Union """
LLM Calling done in `openai/openai.py`
from httpx._config import Timeout """
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.openai.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from .transformation import AzureAIStudioConfig
class AzureAIChatCompletion(OpenAIChatCompletion):
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable[..., Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
transformed_messages = AzureAIStudioConfig()._transform_messages(
messages=messages # type: ignore
)
return super().completion(
model_response,
timeout,
optional_params,
logging_obj,
model,
transformed_messages,
print_verbose,
api_key,
api_base,
acompletion,
litellm_params,
logger_fn,
headers,
custom_prompt_dict,
client,
organization,
custom_llm_provider,
drop_params,
)


@ -13,6 +13,7 @@ from typing import (
Iterator, Iterator,
List, List,
Optional, Optional,
TypedDict,
Union, Union,
) )
@ -34,15 +35,25 @@ class BaseLLMException(Exception):
self, self,
status_code: int, status_code: int,
message: str, message: str,
headers: Optional[Union[httpx.Headers, Dict]] = None, headers: Optional[Union[dict, httpx.Headers]] = None,
request: Optional[httpx.Request] = None, request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None, response: Optional[httpx.Response] = None,
): ):
self.status_code = status_code self.status_code = status_code
self.message: str = message self.message: str = message
self.headers = headers self.headers = headers
self.request = httpx.Request(method="POST", url="https://docs.litellm.ai/docs") if request:
self.response = httpx.Response(status_code=status_code, request=self.request) self.request = request
else:
self.request = httpx.Request(
method="POST", url="https://docs.litellm.ai/docs"
)
if response:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
@ -117,12 +128,6 @@ class BaseConfig(ABC):
) -> dict: ) -> dict:
pass pass
@abstractmethod
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
pass
@abstractmethod @abstractmethod
def transform_response( def transform_response(
self, self,
@ -133,7 +138,8 @@ class BaseConfig(ABC):
request_data: dict, request_data: dict,
messages: List[AllMessageValues], messages: List[AllMessageValues],
optional_params: dict, optional_params: dict,
encoding: str, litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None, api_key: Optional[str] = None,
json_mode: Optional[bool] = None, json_mode: Optional[bool] = None,
) -> ModelResponse: ) -> ModelResponse:


@@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
 import types
 from typing import Optional, Union
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
-class CerebrasConfig:
+class CerebrasConfig(OpenAIGPTConfig):
     """
     Reference: https://inference-docs.cerebras.ai/api-reference/chat-completions
@@ -18,9 +20,7 @@ class CerebrasConfig:
     max_tokens: Optional[int] = None
     response_format: Optional[dict] = None
     seed: Optional[int] = None
-    stop: Optional[str] = None
     stream: Optional[bool] = None
-    temperature: Optional[float] = None
     top_p: Optional[int] = None
     tool_choice: Optional[str] = None
     tools: Optional[list] = None
@@ -46,21 +46,7 @@ class CerebrasConfig:
     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()
     def get_supported_openai_params(self, model: str) -> list:
         """
@@ -83,7 +69,11 @@ class CerebrasConfig:
         ]
     def map_openai_params(
-        self, model: str, non_default_params: dict, optional_params: dict
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
     ) -> dict:
         supported_openai_params = self.get_supported_openai_params(model=model)
         for param, value in non_default_params.items():


@@ -148,6 +148,7 @@ class ClarifaiConfig(BaseConfig):
         request_data: dict,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         encoding: str,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,


@@ -49,6 +49,10 @@ class CloudflareChatConfig(BaseConfig):
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
     def validate_environment(
         self,
         headers: dict,
@@ -120,6 +124,7 @@ class CloudflareChatConfig(BaseConfig):
         request_data: dict,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         encoding: str,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,


@@ -216,7 +216,8 @@ class CohereChatConfig(BaseConfig):
         request_data: dict,
         messages: List[AllMessageValues],
         optional_params: dict,
-        encoding: str,
+        litellm_params: dict,
+        encoding: Any,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:


@@ -217,7 +217,8 @@ class CohereTextConfig(BaseConfig):
         request_data: dict,
         messages: List[AllMessageValues],
         optional_params: dict,
-        encoding: str,
+        litellm_params: dict,
+        encoding: Any,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:


@ -211,7 +211,6 @@ class AsyncHTTPHandler:
headers=headers, headers=headers,
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
if stream is True: if stream is True:
setattr(e, "message", await e.response.aread()) setattr(e, "message", await e.response.aread())
setattr(e, "text", await e.response.aread()) setattr(e, "text", await e.response.aread())


@ -51,7 +51,8 @@ class BaseLLMHTTPHandler:
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
messages: list, messages: list,
optional_params: dict, optional_params: dict,
encoding: str, litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None, api_key: Optional[str] = None,
): ):
async_httpx_client = get_async_httpx_client( async_httpx_client = get_async_httpx_client(
@ -75,6 +76,7 @@ class BaseLLMHTTPHandler:
request_data=data, request_data=data,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding, encoding=encoding,
) )
@ -163,6 +165,7 @@ class BaseLLMHTTPHandler:
api_key=api_key, api_key=api_key,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding, encoding=encoding,
) )
@ -211,6 +214,7 @@ class BaseLLMHTTPHandler:
request_data=data, request_data=data,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding, encoding=encoding,
) )


@ -10,14 +10,14 @@ from pydantic import BaseModel
from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField from litellm.types.utils import ProviderField
from ...openai.chat.gpt_transformation import OpenAIGPTConfig from ...openai_like.chat.transformation import OpenAILikeChatConfig
from ...prompt_templates.common_utils import ( from ...prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion, handle_messages_with_content_list_to_str_conversion,
strip_name_from_messages, strip_name_from_messages,
) )
class DatabricksConfig(OpenAIGPTConfig): class DatabricksConfig(OpenAILikeChatConfig):
""" """
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
""" """
@ -85,30 +85,6 @@ class DatabricksConfig(OpenAIGPTConfig):
return False return False
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["n"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "response_format":
optional_params["response_format"] = value
return optional_params
def _transform_messages( def _transform_messages(
self, messages: List[AllMessageValues] self, messages: List[AllMessageValues]
) -> List[AllMessageValues]: ) -> List[AllMessageValues]:


@ -0,0 +1,120 @@
import types
from typing import Optional, Tuple, Union
import litellm
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.secret_managers.main import get_secret_str
class DeepInfraConfig(OpenAIGPTConfig):
"""
Reference: https://deepinfra.com/docs/advanced/openai_api
The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
tools: Optional[list] = None
tool_choice: Optional[Union[str, dict]] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
tools: Optional[list] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str):
return [
"stream",
"frequency_penalty",
"function_call",
"functions",
"logit_bias",
"max_tokens",
"max_completion_tokens",
"n",
"presence_penalty",
"stop",
"temperature",
"top_p",
"response_format",
"tools",
"tool_choice",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():
if (
param == "temperature"
and value == 0
and model == "mistralai/Mistral-7B-Instruct-v0.1"
): # this model does not support temperature == 0
value = 0.0001 # close to 0
if param == "tool_choice":
if (
value != "auto" and value != "none"
): # https://deepinfra.com/docs/advanced/function_calling
## UNSUPPORTED TOOL CHOICE VALUE
if litellm.drop_params is True or drop_params is True:
value = None
else:
raise litellm.utils.UnsupportedParamsError(
message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
value
),
status_code=400,
)
elif param == "max_completion_tokens":
optional_params["max_tokens"] = value
elif param in supported_openai_params:
if value is not None:
optional_params[param] = value
return optional_params
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
# deepinfra is openai compatible, we just need to set the api_base to https://api.deepinfra.com/v1/openai
api_base = (
api_base
or get_secret_str("DEEPINFRA_API_BASE")
or "https://api.deepinfra.com/v1/openai"
)
dynamic_api_key = api_key or get_secret_str("DEEPINFRA_API_KEY")
return api_base, dynamic_api_key
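The new config resolves its base URL and key from `DEEPINFRA_API_BASE` / `DEEPINFRA_API_KEY`, maps `max_completion_tokens` to `max_tokens`, and rejects forced `tool_choice` values unless params are dropped. A hedged usage sketch (model name is illustrative):

```python
# Sketch only: model name is illustrative; a real DEEPINFRA_API_KEY is required.
import os

import litellm

os.environ["DEEPINFRA_API_KEY"] = "my-key"  # placeholder

response = litellm.completion(
    model="deepinfra/meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "hello"}],
    max_completion_tokens=32,  # mapped to max_tokens by DeepInfraConfig.map_openai_params
    drop_params=True,  # lets an unsupported forced tool_choice be dropped instead of raising
)
print(response.choices[0].message.content)
```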

View file

@ -1,421 +0,0 @@
# ####################################
# ######### DEPRECATED FILE ##########
# ####################################
# # logic moved to `vertex_httpx.py` #
import copy
import time
import traceback
import types
from typing import Callable, Optional
import httpx
from packaging.version import Version
import litellm
from litellm import verbose_logger
from litellm.utils import Choices, Message, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
class GeminiError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class GeminiConfig:
"""
Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
The class `GeminiConfig` provides configuration for the Gemini's API interface. Here are the parameters:
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
- `max_output_tokens` (int): The maximum number of tokens to include in a candidate. If unset, this will default to output_token_limit specified in the model's specification.
- `temperature` (float): Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature attribute of the Model returned by the genai.get_model function. Values can range from [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied and creative, while a value closer to 0.0 will typically result in more straightforward responses from the model.
- `top_p` (float): Optional. The maximum cumulative probability of tokens to consider when sampling.
- `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
"""
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
max_output_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
def __init__(
self,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# class TextStreamer:
# """
# A class designed to return an async stream from AsyncGenerateContentResponse object.
# """
# def __init__(self, response):
# self.response = response
# self._aiter = self.response.__aiter__()
# async def __aiter__(self):
# while True:
# try:
# # This will manually advance the async iterator.
# # In the case the next object doesn't exists, __anext__() will simply raise a StopAsyncIteration exception
# next_object = await self._aiter.__anext__()
# yield next_object
# except StopAsyncIteration:
# # After getting all items from the async iterator, stop iterating
# break
# def supports_system_instruction():
# import google.generativeai as genai
# gemini_pkg_version = Version(genai.__version__)
# return gemini_pkg_version >= Version("0.5.0")
# def completion(
# model: str,
# messages: list,
# model_response: ModelResponse,
# print_verbose: Callable,
# api_key,
# encoding,
# logging_obj,
# custom_prompt_dict: dict,
# acompletion: bool = False,
# optional_params=None,
# litellm_params=None,
# logger_fn=None,
# ):
# try:
# import google.generativeai as genai # type: ignore
# except Exception:
# raise Exception(
# "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
# )
# genai.configure(api_key=api_key)
# system_prompt = ""
# if model in custom_prompt_dict:
# # check if the model has a registered custom prompt
# model_prompt_details = custom_prompt_dict[model]
# prompt = custom_prompt(
# role_dict=model_prompt_details["roles"],
# initial_prompt_value=model_prompt_details["initial_prompt_value"],
# final_prompt_value=model_prompt_details["final_prompt_value"],
# messages=messages,
# )
# else:
# system_prompt, messages = get_system_prompt(messages=messages)
# prompt = prompt_factory(
# model=model, messages=messages, custom_llm_provider="gemini"
# )
# ## Load Config
# inference_params = copy.deepcopy(optional_params)
# stream = inference_params.pop("stream", None)
# # Handle safety settings
# safety_settings_param = inference_params.pop("safety_settings", None)
# safety_settings = None
# if safety_settings_param:
# safety_settings = [
# genai.types.SafetySettingDict(x) for x in safety_settings_param
# ]
# config = litellm.GeminiConfig.get_config()
# for k, v in config.items():
# if (
# k not in inference_params
# ): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
# inference_params[k] = v
# ## LOGGING
# logging_obj.pre_call(
# input=prompt,
# api_key="",
# additional_args={
# "complete_input_dict": {
# "inference_params": inference_params,
# "system_prompt": system_prompt,
# }
# },
# )
# ## COMPLETION CALL
# try:
# _params = {"model_name": "models/{}".format(model)}
# _system_instruction = supports_system_instruction()
# if _system_instruction and len(system_prompt) > 0:
# _params["system_instruction"] = system_prompt
# _model = genai.GenerativeModel(**_params)
# if stream is True:
# if acompletion is True:
# async def async_streaming():
# try:
# response = await _model.generate_content_async(
# contents=prompt,
# generation_config=genai.types.GenerationConfig(
# **inference_params
# ),
# safety_settings=safety_settings,
# stream=True,
# )
# response = litellm.CustomStreamWrapper(
# TextStreamer(response),
# model,
# custom_llm_provider="gemini",
# logging_obj=logging_obj,
# )
# return response
# except Exception as e:
# raise GeminiError(status_code=500, message=str(e))
# return async_streaming()
# response = _model.generate_content(
# contents=prompt,
# generation_config=genai.types.GenerationConfig(**inference_params),
# safety_settings=safety_settings,
# stream=True,
# )
# return response
# elif acompletion == True:
# return async_completion(
# _model=_model,
# model=model,
# prompt=prompt,
# inference_params=inference_params,
# safety_settings=safety_settings,
# logging_obj=logging_obj,
# print_verbose=print_verbose,
# model_response=model_response,
# messages=messages,
# encoding=encoding,
# )
# else:
# params = {
# "contents": prompt,
# "generation_config": genai.types.GenerationConfig(**inference_params),
# "safety_settings": safety_settings,
# }
# response = _model.generate_content(**params)
# except Exception as e:
# raise GeminiError(
# message=str(e),
# status_code=500,
# )
# ## LOGGING
# logging_obj.post_call(
# input=prompt,
# api_key="",
# original_response=response,
# additional_args={"complete_input_dict": {}},
# )
# print_verbose(f"raw model_response: {response}")
# ## RESPONSE OBJECT
# completion_response = response
# try:
# choices_list = []
# for idx, item in enumerate(completion_response.candidates):
# if len(item.content.parts) > 0:
# message_obj = Message(content=item.content.parts[0].text)
# else:
# message_obj = Message(content=None)
# choice_obj = Choices(index=idx, message=message_obj)
# choices_list.append(choice_obj)
# model_response.choices = choices_list
# except Exception as e:
# verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
# raise GeminiError(
# message=traceback.format_exc(), status_code=response.status_code
# )
# try:
# completion_response = model_response["choices"][0]["message"].get("content")
# if completion_response is None:
# raise Exception
# except Exception:
# original_response = f"response: {response}"
# if hasattr(response, "candidates"):
# original_response = f"response: {response.candidates}"
# if "SAFETY" in original_response:
# original_response += (
# "\nThe candidate content was flagged for safety reasons."
# )
# elif "RECITATION" in original_response:
# original_response += (
# "\nThe candidate content was flagged for recitation reasons."
# )
# raise GeminiError(
# status_code=400,
# message=f"No response received. Original response - {original_response}",
# )
# ## CALCULATING USAGE
# prompt_str = ""
# for m in messages:
# if isinstance(m["content"], str):
# prompt_str += m["content"]
# elif isinstance(m["content"], list):
# for content in m["content"]:
# if content["type"] == "text":
# prompt_str += content["text"]
# prompt_tokens = len(encoding.encode(prompt_str))
# completion_tokens = len(
# encoding.encode(model_response["choices"][0]["message"].get("content", ""))
# )
# model_response.created = int(time.time())
# model_response.model = "gemini/" + model
# usage = Usage(
# prompt_tokens=prompt_tokens,
# completion_tokens=completion_tokens,
# total_tokens=prompt_tokens + completion_tokens,
# )
# setattr(model_response, "usage", usage)
# return model_response
# async def async_completion(
# _model,
# model,
# prompt,
# inference_params,
# safety_settings,
# logging_obj,
# print_verbose,
# model_response,
# messages,
# encoding,
# ):
# import google.generativeai as genai # type: ignore
# response = await _model.generate_content_async(
# contents=prompt,
# generation_config=genai.types.GenerationConfig(**inference_params),
# safety_settings=safety_settings,
# )
# ## LOGGING
# logging_obj.post_call(
# input=prompt,
# api_key="",
# original_response=response,
# additional_args={"complete_input_dict": {}},
# )
# print_verbose(f"raw model_response: {response}")
# ## RESPONSE OBJECT
# completion_response = response
# try:
# choices_list = []
# for idx, item in enumerate(completion_response.candidates):
# if len(item.content.parts) > 0:
# message_obj = Message(content=item.content.parts[0].text)
# else:
# message_obj = Message(content=None)
# choice_obj = Choices(index=idx, message=message_obj)
# choices_list.append(choice_obj)
# model_response["choices"] = choices_list
# except Exception as e:
# verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
# raise GeminiError(
# message=traceback.format_exc(), status_code=response.status_code
# )
# try:
# completion_response = model_response["choices"][0]["message"].get("content")
# if completion_response is None:
# raise Exception
# except Exception:
# original_response = f"response: {response}"
# if hasattr(response, "candidates"):
# original_response = f"response: {response.candidates}"
# if "SAFETY" in original_response:
# original_response += (
# "\nThe candidate content was flagged for safety reasons."
# )
# elif "RECITATION" in original_response:
# original_response += (
# "\nThe candidate content was flagged for recitation reasons."
# )
# raise GeminiError(
# status_code=400,
# message=f"No response received. Original response - {original_response}",
# )
# ## CALCULATING USAGE
# prompt_str = ""
# for m in messages:
# if isinstance(m["content"], str):
# prompt_str += m["content"]
# elif isinstance(m["content"], list):
# for content in m["content"]:
# if content["type"] == "text":
# prompt_str += content["text"]
# prompt_tokens = len(encoding.encode(prompt_str))
# completion_tokens = len(
# encoding.encode(model_response["choices"][0]["message"].get("content", ""))
# )
# model_response["created"] = int(time.time())
# model_response["model"] = "gemini/" + model
# usage = Usage(
# prompt_tokens=prompt_tokens,
# completion_tokens=completion_tokens,
# total_tokens=prompt_tokens + completion_tokens,
# )
# model_response.usage = usage
# return model_response
# def embedding():
# # logic for parsing in - calling - parsing out model embedding calls
# pass
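This file is deleted because the SDK-based path is dead code: the header notes its logic moved to `vertex_httpx.py`, and Gemini traffic already goes through the httpx-based handler. For reference, a hedged sketch of the equivalent call on the surviving path (model name is illustrative, and it assumes `GEMINI_API_KEY` is set):

```python
# Sketch: Gemini calls route through litellm's httpx-based handler; the
# google-generativeai SDK used by this deleted file is no longer needed.
import os

import litellm

os.environ["GEMINI_API_KEY"] = "my-key"  # placeholder

response = litellm.completion(
    model="gemini/gemini-1.5-flash",  # illustrative model name
    messages=[{"role": "user", "content": "say hi"}],
    max_tokens=32,
    temperature=0.7,
)
print(response.choices[0].message.content)
```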

View file

@ -0,0 +1,750 @@
## Uses the huggingface text generation inference API
import copy
import enum
import json
import os
import time
import types
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
get_args,
)
import httpx
import requests
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.huggingface.chat.transformation import (
HuggingfaceChatConfig as HuggingfaceConfig,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Logprobs as TextCompletionLogprobs
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from ...base import BaseLLM
from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks
hf_chat_config = HuggingfaceConfig()
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]
def get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = HTTPHandler(concurrent_limit=1)
model_info = http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def async_get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
model_info = await http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
try:
response = await client.post(
api_base, headers=headers, data=data, stream=True, timeout=timeout
)
except httpx.HTTPStatusError as e:
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise HuggingfaceError(
status_code=e.response.status_code,
message=str(await e.response.aread()),
headers=cast(dict, error_headers) if error_headers else None,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise HuggingfaceError(status_code=500, message=str(e))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return response.aiter_lines(), response.headers
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
def completion( # noqa: PLR0915
self,
model: str,
messages: list,
api_base: Optional[str],
model_response: ModelResponse,
print_verbose: Callable,
timeout: float,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params: dict,
custom_prompt_dict={},
acompletion: bool = False,
logger_fn=None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers: dict = {},
):
super().completion()
exception_mapping_worked = False
try:
task, model = hf_chat_config.get_hf_task_for_model(model)
litellm_params["task"] = task
headers = hf_chat_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
completion_url = hf_chat_config.get_api_base(api_base=api_base, model=model)
data = hf_chat_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=data,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages) # type: ignore
else:
### ASYNC COMPLETION
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, model=model, optional_params=optional_params, timeout=timeout, litellm_params=litellm_params) # type: ignore
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler()
### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] is True:
response = client.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
### SYNC COMPLETION
else:
response = client.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return hf_chat_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
json_mode=None,
litellm_params=litellm_params,
)
except httpx.HTTPStatusError as e:
raise HuggingfaceError(
status_code=e.response.status_code,
message=e.response.text,
headers=e.response.headers,
)
except HuggingfaceError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(
self,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
encoding: Any,
model: str,
optional_params: dict,
litellm_params: dict,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: str,
messages: List[AllMessageValues],
):
response: Optional[httpx.Response] = None
try:
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE
)
### ASYNC COMPLETION
http_response = await http_client.post(
url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
)
response = http_response
return hf_chat_config.transform_response(
model=model,
raw_response=http_response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
json_mode=None,
litellm_params=litellm_params,
)
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif isinstance(e, HuggingfaceError):
raise e
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
headers=response.headers,
)
else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
messages: List[AllMessageValues],
model: str,
timeout: float,
client: Optional[AsyncHTTPHandler] = None,
):
completion_stream, _ = await make_call(
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=False,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
return streamwrapper
def _transform_input_on_pipeline_tag(
self, input: List, pipeline_tag: Optional[str]
) -> dict:
if pipeline_tag is None:
return {"inputs": input}
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="sentence-similarity requires 2+ sentences",
)
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
elif pipeline_tag == "rerank":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="reranker requires 2+ sentences",
)
return {"inputs": {"query": input[0], "texts": input[1:]}}
return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input(
self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
special_options_keys = HuggingfaceConfig().get_special_options_params()
special_parameters_keys = [
"min_length",
"max_length",
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"max_time",
]
for k, v in optional_params.items():
if k in special_options_keys:
data.setdefault("options", {})
data["options"][k] = v
elif k in special_parameters_keys:
data.setdefault("parameters", {})
data["parameters"][k] = v
else:
data[k] = v
return data
def _transform_input(
self,
input: List,
model: str,
call_type: Literal["sync", "async"],
optional_params: dict,
embed_url: str,
) -> dict:
data: Dict = {}
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input}
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
) # type: ignore
data = self._transform_input_on_pipeline_tag(
input=input, pipeline_tag=hf_task
)
if len(optional_params.keys()) > 0:
data = self._process_optional_params(
data=data, optional_params=optional_params
)
return data
def _process_embedding_response(
self,
embeddings: dict,
model_response: litellm.EmbeddingResponse,
model: str,
input: List,
encoding: Any,
) -> litellm.EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=input_tokens,
completion_tokens=input_tokens,
total_tokens=input_tokens,
prompt_tokens_details=None,
completion_tokens_details=None,
),
)
return model_response
async def aembedding(
self,
model: str,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def embedding(
self,
model: str,
input: list,
model_response: litellm.EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
) -> litellm.EmbeddingResponse:
super().embedding()
headers = hf_chat_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
optional_params=optional_params,
messages=[],
)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
## ROUTING ##
if aembedding is True:
return self.aembedding(
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
headers=headers,
api_base=embed_url, # type: ignore
api_key=api_key,
client=client if isinstance(client, AsyncHTTPHandler) else None,
model=model,
optional_params=optional_params,
encoding=encoding,
)
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=embed_url,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def _transform_logprobs(
self, hf_response: Optional[List]
) -> Optional[TextCompletionLogprobs]:
"""
Transform Hugging Face logprobs to OpenAI.Completion() format
"""
if hf_response is None:
return None
# Initialize an empty list for the transformed logprobs
_logprob: TextCompletionLogprobs = TextCompletionLogprobs(
text_offset=[],
token_logprobs=[],
tokens=[],
top_logprobs=[],
)
# For each Hugging Face response, transform the logprobs
for response in hf_response:
# Extract the relevant information from the response
response_details = response["details"]
top_tokens = response_details.get("top_tokens", {})
for i, token in enumerate(response_details["prefill"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
# Add the token information to the 'token_info' list
_logprob.tokens.append(token_text)
_logprob.token_logprobs.append(token_logprob)
# stub this to work with llm eval harness
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
_logprob.top_logprobs.append(top_alt_tokens)
# For each element in the 'tokens' list, extract the relevant information
for i, token in enumerate(response_details["tokens"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
top_alt_tokens = {}
temp_top_logprobs = []
if top_tokens != {}:
temp_top_logprobs = top_tokens[i]
# top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
for elem in temp_top_logprobs:
text = elem["text"]
logprob = elem["logprob"]
top_alt_tokens[text] = logprob
# Add the token information to the 'token_info' list
_logprob.tokens.append(token_text)
_logprob.token_logprobs.append(token_logprob)
_logprob.top_logprobs.append(top_alt_tokens)
# Add the text offset of the token
# This is computed as the sum of the lengths of all previous tokens
_logprob.text_offset.append(
sum(len(t["text"]) for t in response_details["tokens"][:i])
)
return _logprob
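The embedding path above builds a different payload per pipeline tag; for `sentence-similarity` (and `similarity`) the first input becomes `source_sentence` and the rest become `sentences`. A small sketch of what `_transform_input_on_pipeline_tag` produces, using the class defined above (inputs are illustrative):

```python
# Sketch: exercising the pure payload transformation shown above.
hf = Huggingface()

payload = hf._transform_input_on_pipeline_tag(
    input=["a cat sat on a mat", "a feline rested on a rug", "stock prices fell"],
    pipeline_tag="sentence-similarity",
)
print(payload)
# {'inputs': {'source_sentence': 'a cat sat on a mat',
#             'sentences': ['a feline rested on a rug', 'stock prices fell']}}
```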

View file

@ -0,0 +1,590 @@
import json
import os
import time
import types
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks, output_parser
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
tgi_models_cache = None
conv_models_cache = None
class HuggingfaceChatConfig(BaseConfig):
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[hf_tasks] = (
None # litellm-specific param, used to know the api spec to use when calling huggingface api
)
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[bool] = (
False # by default don't return the input as part of the output
)
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_n_tokens: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_special_options_params(self):
return ["use_cache", "wait_for_model"]
def get_supported_openai_params(self, model: str):
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens" or param == "max_completion_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
def get_hf_api_key(self) -> Optional[str]:
return get_secret_str("HUGGINGFACE_API_KEY")
def read_tgi_conv_models(self):
try:
global tgi_models_cache, conv_models_cache
# Check if the cache is already populated
# so we don't keep on reading txt file if there are 1k requests
if (tgi_models_cache is not None) and (conv_models_cache is not None):
return tgi_models_cache, conv_models_cache
# If not, read the file and populate the cache
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
script_directory = os.path.dirname(script_directory)
# Construct the file path relative to the script's directory
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
# Cache the set for future use
tgi_models_cache = tgi_models
# If not, read the file and populate the cache
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
conv_models_cache = conv_models
return tgi_models, conv_models
except Exception:
return set(), set()
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
split_model = model.split("/", 1)
return split_model[0], split_model[1] # type: ignore
tgi_models, conversational_models = self.read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference", model
elif model in conversational_models:
return "conversational", model
elif "roneneldan/TinyStories" in model:
return "text-generation", model
else:
return "text-generation-inference", model # default to tgi
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
task = litellm_params.get("task", None)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
## Load Config
config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
#### HANDLE SPECIAL PARAMS
special_params = self.get_special_options_params()
special_params_dict = {}
# Create a list of keys to pop after iteration
keys_to_pop = []
for k, v in optional_params.items():
if k in special_params:
special_params_dict[k] = v
keys_to_pop.append(k)
# Pop the keys from the dictionary after iteration
for k in keys_to_pop:
optional_params.pop(k)
if task == "conversational":
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
past_user_inputs = []
generated_responses = []
text = ""
for message in messages:
if message["role"] == "user":
if text != "":
past_user_inputs.append(text)
text = convert_content_list_to_str(message)
elif message["role"] == "assistant" or message["role"] == "system":
generated_responses.append(convert_content_list_to_str(message))
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params,
}
elif task == "text-generation-inference":
# always send "details" and "return_full_text" as params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True
if "stream" in optional_params
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] is True # type: ignore
else False
),
}
else:
# Non TGI and Conversational llms
# We need this branch, it removes 'details' and 'return_full_text' from params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params and optional_params["stream"] is True
else False
)
### RE-ADD SPECIAL PARAMS
if len(special_params_dict.keys()) > 0:
data.update({"options": special_params_dict})
return data
def get_api_base(self, api_base: Optional[str], model: str) -> str:
"""
Get the API base for the Huggingface API.
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if "https" in model:
completion_url = model
elif api_base is not None:
completion_url = api_base
elif "HF_API_BASE" in os.environ:
completion_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
return completion_url
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
) -> Dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers["Authorization"] = (
f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
)
headers = {**headers, **default_headers}
return headers
def _transform_messages(
self,
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
return messages
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingfaceError(
status_code=status_code, message=error_message, headers=headers
)
def _convert_streamed_response_to_complete_response(
self,
response: httpx.Response,
logging_obj: LoggingClass,
model: str,
data: dict,
api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
## LOGGING
logging_obj.post_call(
input=data,
api_key=api_key,
original_response=completion_response,
additional_args={"complete_input_dict": data},
)
return completion_response
def convert_to_model_response_object( # noqa: PLR0915
self,
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
model_response: litellm.ModelResponse,
task: Optional[hf_tasks],
optional_params: dict,
encoding: Any,
messages: List[AllMessageValues],
model: str,
):
if task is None:
task = "text-generation-inference" # default to tgi
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response.choices[0].message.content = completion_response[ # type: ignore
"generated_text"
]
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]
):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
headers=None,
)
if len(completion_response[0]["generated_text"]) > 0:
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices.extend(choices_list)
elif task == "text-classification":
model_response.choices[0].message.content = json.dumps( # type: ignore
completion_response
)
else:
if (
isinstance(completion_response, list)
and len(completion_response[0]["generated_text"]) > 0
):
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = token_counter(model=model, messages=messages)
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
completion_tokens = 0
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
model_response._hidden_params["original_response"] = completion_response
return model_response
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
task = litellm_params.get("task", None)
is_streamed = False
if (
raw_response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
completion_response = self._convert_streamed_response_to_complete_response(
response=raw_response,
logging_obj=logging_obj,
model=model,
data=request_data,
api_key=api_key,
)
else:
## LOGGING
logging_obj.post_call(
input=request_data,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
if isinstance(completion_response, dict):
completion_response = [completion_response]
except Exception:
raise HuggingfaceError(
message=f"Original Response received: {raw_response.text}",
status_code=raw_response.status_code,
)
if isinstance(completion_response, dict) and "error" in completion_response:
raise HuggingfaceError(
message=completion_response["error"], # type: ignore
status_code=raw_response.status_code,
)
return self.convert_to_model_response_object(
completion_response=completion_response,
model_response=model_response,
task=task if task is not None and task in hf_task_list else None,
optional_params=optional_params,
encoding=encoding,
messages=messages,
model=model,
)
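`get_hf_task_for_model` lets callers pin the API spec by prefixing the model with a task from `hf_task_list`, and falls back to the TGI spec otherwise. A quick sketch of both branches (model names are illustrative):

```python
# Sketch: task routing as implemented by get_hf_task_for_model above.
cfg = HuggingfaceChatConfig()

# explicit task prefix -> (task, remaining model name)
print(cfg.get_hf_task_for_model("conversational/facebook/blenderbot-400M-distill"))
# ('conversational', 'facebook/blenderbot-400M-distill')

# no recognised task prefix -> defaults to text-generation-inference
print(cfg.get_hf_task_for_model("mistralai/Mistral-7B-Instruct-v0.2"))
# ('text-generation-inference', 'mistralai/Mistral-7B-Instruct-v0.2')
```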

View file

@ -0,0 +1,45 @@
from typing import Literal, Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
class HuggingfaceError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)
hf_tasks = Literal[
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
hf_task_list = [
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
def output_parser(generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
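For example, the parser strips at most one leading chat-template token and one trailing one, leaving the inner text untouched:

```python
# Sketch: behaviour of output_parser above.
print(output_parser("<|assistant|> Hello there</s>"))
# prints: ' Hello there'  (leading <|assistant|> and trailing </s> stripped)
```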

File diff suppressed because it is too large

View file

@ -4,59 +4,42 @@ import time
import traceback
import types
from enum import Enum
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Union
import requests # type: ignore
from httpx._models import Headers
import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.utils import Choices, Message, ModelResponse, Usage
class MaritalkError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class MaritalkError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)
class MaritalkConfig(OpenAIGPTConfig):
class MaritTalkConfig:
"""
The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
- `model` (string): The model used for conversation. Default is 'maritalk'.
- `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
- `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
- `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
- `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped.
"""
max_tokens: Optional[int] = None
model: Optional[str] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
stopping_tokens: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
model: Optional[str] = None,
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None,
) -> None:
def __init__(
self,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
stop: Optional[List[str]] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -65,129 +48,27 @@ class MaritTalkConfig:
@classmethod
def get_config(cls):
return super().get_config()
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> List:
return [
"frequency_penalty",
"presence_penalty",
"top_p",
"top_k",
"temperature",
"max_tokens",
"n",
"stop",
"stream",
"stream_options",
"tools",
"tool_choice",
]
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return MaritalkError(
status_code=status_code, message=error_message, headers=headers
)
def validate_environment(api_key):
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Key {api_key}"
return headers
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
completion_url = api_base
model = model
## Load Config
config = litellm.MaritTalkConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
if "error" in completion_response:
raise MaritalkError(
message=completion_response["error"],
status_code=response.status_code,
)
else:
try:
if len(completion_response["answer"]) > 0:
model_response.choices[0].message.content = completion_response[ # type: ignore
"answer"
]
except Exception:
raise MaritalkError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE
prompt = "".join(m["content"] for m in messages)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def embedding(
model: str,
input: list,
api_key: Optional[str],
logging_obj: Any,
model_response=None,
encoding=None,
):
pass
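For reference, here is a minimal sketch (illustrative names, not litellm APIs) of the two pieces of glue the removed Maritalk handler provided around the raw HTTP call: "Key"-prefixed auth headers and usage accounting over the joined prompt and the returned `answer` field.

```python
# Minimal sketch, assuming an encoding object with .encode() is passed in.
from typing import Dict, List, Optional


def build_maritalk_headers(api_key: Optional[str]) -> Dict[str, str]:
    headers = {"accept": "application/json", "content-type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Key {api_key}"  # Maritalk uses a "Key" scheme
    return headers


def count_usage(messages: List[dict], answer: str, encoding) -> Dict[str, int]:
    # prompt tokens: all message contents concatenated; completion tokens: the answer
    prompt_tokens = len(encoding.encode("".join(m["content"] for m in messages)))
    completion_tokens = len(encoding.encode(answer))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
```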


@@ -9,11 +9,16 @@ Docs - https://docs.mistral.ai/api/
import types
from typing import List, Literal, Optional, Tuple, Union
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.prompt_templates.common_utils import (
+handle_messages_with_content_list_to_str_conversion,
+strip_none_values_from_message,
+)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
-class MistralConfig:
+class MistralConfig(OpenAIGPTConfig):
"""
Reference: https://docs.mistral.ai/api/
@@ -67,23 +72,9 @@ class MistralConfig:
@classmethod
def get_config(cls):
-return {
+return super().get_config()
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
-def get_supported_openai_params(self):
+def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
@@ -104,7 +95,13 @@ class MistralConfig:
else:  # openai 'tool_choice' object param not supported by Mistral API
return "any"
-def map_openai_params(self, non_default_params: dict, optional_params: dict):
+def map_openai_params(
+self,
+non_default_params: dict,
+optional_params: dict,
+model: str,
+drop_params: bool,
+) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
@@ -150,8 +147,9 @@ class MistralConfig:
)
return api_base, dynamic_api_key
-@classmethod
-def _transform_messages(cls, messages: List[AllMessageValues]):
+def _transform_messages(
+self, messages: List[AllMessageValues]
+) -> List[AllMessageValues]:
"""
- handles scenario where content is list and not string
- content list is just text, and no images
@ -160,48 +158,36 @@ class MistralConfig:
Motivation: mistral api doesn't support content as a list Motivation: mistral api doesn't support content as a list
""" """
new_messages = [] ## 1. If 'image_url' in content, then return as is
for m in messages: for m in messages:
special_keys = ["role", "content", "tool_calls", "function_call"] _content_block = m.get("content")
extra_args = {} if _content_block and isinstance(_content_block, list):
if isinstance(m, dict): for c in _content_block:
for k, v in m.items(): if c.get("type") == "image_url":
if k not in special_keys:
extra_args[k] = v
texts = ""
_content = m.get("content")
if _content is not None and isinstance(_content, list):
for c in _content:
_text: Optional[str] = c.get("text")
if c["type"] == "image_url":
return messages return messages
elif c["type"] == "text" and isinstance(_text, str):
texts += _text
elif _content is not None and isinstance(_content, str):
texts = _content
new_m = {"role": m["role"], "content": texts, **extra_args} ## 2. If content is list, then convert to string
messages = handle_messages_with_content_list_to_str_conversion(messages)
if m.get("tool_calls"): ## 3. Handle name in message
new_m["tool_calls"] = m.get("tool_calls") new_messages: List[AllMessageValues] = []
for m in messages:
m = MistralConfig._handle_name_in_message(m)
m = strip_none_values_from_message(m) # prevents 'extra_forbidden' error
new_messages.append(m)
new_m = cls._handle_name_in_message(new_m)
new_messages.append(new_m)
return new_messages return new_messages
@classmethod
-def _handle_name_in_message(cls, message: dict) -> dict:
+def _handle_name_in_message(cls, message: AllMessageValues) -> AllMessageValues:
"""
Mistral API only supports `name` in tool messages
If role == tool, then we keep `name`
Otherwise, we drop `name`
"""
-if message.get("name") is not None:
-if message["role"] == "tool":
-message["name"] = message.get("name")
-else:
-message.pop("name", None)
+_name = message.get("name")  # type: ignore
+if _name is not None and message["role"] != "tool":
+message.pop("name", None)  # type: ignore
return message
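The new `_transform_messages` flow above can be hard to follow in diff form, so here is an illustrative, standalone sketch of the same three steps (pass through vision payloads, flatten text-only content lists, drop `name` outside tool messages and strip `None` values). It is not the litellm source; the real entry point is `MistralConfig()._transform_messages(...)`.

```python
# Illustrative sketch under stated assumptions; names here are not litellm APIs.
from typing import Any, Dict, List


def normalize_for_mistral(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # 1. if any content block is an image_url, leave the messages untouched
    for m in messages:
        content = m.get("content")
        if isinstance(content, list) and any(
            isinstance(c, dict) and c.get("type") == "image_url" for c in content
        ):
            return messages

    normalized: List[Dict[str, Any]] = []
    for m in messages:
        # 2. flatten list-form content into a plain string
        content = m.get("content")
        if isinstance(content, list):
            content = "".join(
                c.get("text", "")
                for c in content
                if isinstance(c, dict) and c.get("type") == "text"
            )
        new_m = {**m, "content": content}
        # 3. Mistral only accepts `name` on tool messages; also drop None values
        if new_m.get("role") != "tool":
            new_m.pop("name", None)
        normalized.append({k: v for k, v in new_m.items() if v is not None})
    return normalized


print(normalize_for_mistral(
    [{"role": "user", "content": [{"type": "text", "text": "hi"}], "name": None}]
))
# -> [{'role': 'user', 'content': 'hi'}]
```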


@ -0,0 +1,140 @@
import json
import os
import time
import types
from enum import Enum
from typing import Any, Callable, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse, Usage
from ..common_utils import NLPCloudError
from .transformation import NLPCloudConfig
nlp_config = NLPCloudConfig()
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params: dict,
logger_fn=None,
default_max_tokens_to_sample=None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
):
headers = nlp_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
## Load Config
config = litellm.NLPCloudConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
completion_url_fragment_1 = api_base
completion_url_fragment_2 = "/generation"
model = model
completion_url = completion_url_fragment_1 + model + completion_url_fragment_2
data = nlp_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=None,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
response = client.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] is True:
return clean_and_iterate_chunks(response)
else:
return nlp_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
# def clean_and_iterate_chunks(response):
# def process_chunk(chunk):
# print(f"received chunk: {chunk}")
# cleaned_chunk = chunk.decode("utf-8")
# # Perform further processing based on your needs
# return cleaned_chunk
# for line in response.iter_lines():
# if line:
# yield process_chunk(line)
def clean_and_iterate_chunks(response):
buffer = b""
for chunk in response.iter_content(chunk_size=1024):
if not chunk:
break
buffer += chunk
while b"\x00" in buffer:
buffer = buffer.replace(b"\x00", b"")
yield buffer.decode("utf-8")
buffer = b""
# No more data expected, yield any remaining data in the buffer
if buffer:
yield buffer.decode("utf-8")
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
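A standalone sketch of the chunk cleaner used for NLP Cloud streaming: the provider pads streamed chunks with NUL bytes, so the generator buffers raw bytes, strips `b"\x00"`, and yields UTF-8 text. `FakeResponse` below exists only to make the demo runnable and is not a litellm class.

```python
from typing import Iterator, List


class FakeResponse:
    """Stand-in for an HTTP response object exposing iter_content()."""

    def __init__(self, chunks: List[bytes]) -> None:
        self._chunks = chunks

    def iter_content(self, chunk_size: int = 1024) -> Iterator[bytes]:
        yield from self._chunks


def clean_and_iterate_chunks(response) -> Iterator[str]:
    buffer = b""
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:
            break
        buffer += chunk
        while b"\x00" in buffer:
            buffer = buffer.replace(b"\x00", b"")  # drop NUL padding
        yield buffer.decode("utf-8")
        buffer = b""
    if buffer:  # flush anything left once the stream ends
        yield buffer.decode("utf-8")


print(list(clean_and_iterate_chunks(FakeResponse([b"Hel\x00lo", b" wor\x00ld"]))))
# -> ['Hello', ' world']
```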


@ -1,26 +1,25 @@
import json import json
import os
import time import time
import types from typing import TYPE_CHECKING, Any, List, Optional, Union
from enum import Enum
from typing import Callable, Optional
import requests # type: ignore import httpx
import litellm from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from ..common_utils import NLPCloudError
class NLPCloudError(Exception): if TYPE_CHECKING:
def __init__(self, status_code, message): from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
self.status_code = status_code
self.message = message LoggingClass = LiteLLMLoggingObj
super().__init__( else:
self.message LoggingClass = Any
) # Call the base class constructor with the parameters it needs
class NLPCloudConfig: class NLPCloudConfig(BaseConfig):
""" """
Reference: https://docs.nlpcloud.com/#generation Reference: https://docs.nlpcloud.com/#generation
@@ -84,106 +83,119 @@ class NLPCloudConfig:
@classmethod
def get_config(cls):
-return {
+return super().get_config()
k: v
for k, v in cls.__dict__.items() def validate_environment(
if not k.startswith("__") self,
and not isinstance( headers: dict,
v, model: str,
( messages: List[AllMessageValues],
types.FunctionType, optional_params: dict,
types.BuiltinFunctionType, api_key: Optional[str] = None,
classmethod, ) -> dict:
staticmethod, headers = {
), "accept": "application/json",
) "content-type": "application/json",
and v is not None }
if api_key:
headers["Authorization"] = f"Token {api_key}"
return headers
def get_supported_openai_params(self, model: str) -> List:
return [
"max_tokens",
"stream",
"temperature",
"top_p",
"presence_penalty",
"frequency_penalty",
"n",
"stop",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_length"] = value
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "n":
optional_params["num_return_sequences"] = value
if param == "stop":
optional_params["stop_sequences"] = value
return optional_params
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return NLPCloudError(
status_code=status_code, message=error_message, headers=headers
)
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
text = " ".join(convert_content_list_to_str(message) for message in messages)
data = {
"text": text,
**optional_params,
} }
return data
def validate_environment(api_key): def transform_response(
headers = { self,
"accept": "application/json", model: str,
"content-type": "application/json", raw_response: httpx.Response,
} model_response: ModelResponse,
if api_key: logging_obj: LoggingClass,
headers["Authorization"] = f"Token {api_key}" request_data: dict,
return headers messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
def completion( encoding: Any,
model: str, api_key: Optional[str] = None,
messages: list, json_mode: Optional[bool] = None,
api_base: str, ) -> ModelResponse:
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params=None,
logger_fn=None,
default_max_tokens_to_sample=None,
):
headers = validate_environment(api_key)
## Load Config
config = litellm.NLPCloudConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
completion_url_fragment_1 = api_base
completion_url_fragment_2 = "/generation"
model = model
text = " ".join(message["content"] for message in messages)
data = {
"text": text,
**optional_params,
}
completion_url = completion_url_fragment_1 + model + completion_url_fragment_2
## LOGGING
logging_obj.pre_call(
input=text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] is True:
return clean_and_iterate_chunks(response)
else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=text, input=None,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=raw_response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": request_data},
) )
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = raw_response.json()
except Exception: except Exception:
raise NLPCloudError(message=response.text, status_code=response.status_code) raise NLPCloudError(
message=raw_response.text, status_code=raw_response.status_code
)
if "error" in completion_response: if "error" in completion_response:
raise NLPCloudError( raise NLPCloudError(
message=completion_response["error"], message=completion_response["error"],
status_code=response.status_code, status_code=raw_response.status_code,
) )
else: else:
try: try:
@ -194,7 +206,7 @@ def completion(
except Exception: except Exception:
raise NLPCloudError( raise NLPCloudError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
status_code=response.status_code, status_code=raw_response.status_code,
) )
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
@ -210,37 +222,3 @@ def completion(
) )
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
return model_response return model_response
# def clean_and_iterate_chunks(response):
# def process_chunk(chunk):
# print(f"received chunk: {chunk}")
# cleaned_chunk = chunk.decode("utf-8")
# # Perform further processing based on your needs
# return cleaned_chunk
# for line in response.iter_lines():
# if line:
# yield process_chunk(line)
def clean_and_iterate_chunks(response):
buffer = b""
for chunk in response.iter_content(chunk_size=1024):
if not chunk:
break
buffer += chunk
while b"\x00" in buffer:
buffer = buffer.replace(b"\x00", b"")
yield buffer.decode("utf-8")
buffer = b""
# No more data expected, yield any remaining data in the buffer
if buffer:
yield buffer.decode("utf-8")
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
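A hedged usage sketch of the new `NLPCloudConfig`: OpenAI-style parameters are renamed to NLP Cloud's field names (`max_tokens` -> `max_length`, `n` -> `num_return_sequences`, `stop` -> `stop_sequences`) and the chat messages are collapsed into a single `text` field. `litellm.NLPCloudConfig` is the same attribute the handler above loads its config from; the model name is only an example.

```python
import litellm

config = litellm.NLPCloudConfig()

optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 64, "n": 2, "stop": ["###"]},
    optional_params={},
    model="finetuned-llama-3-70b",  # example model name
    drop_params=False,
)
# expected: {"max_length": 64, "num_return_sequences": 2, "stop_sequences": ["###"]}

body = config.transform_request(
    model="finetuned-llama-3-70b",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params=optional_params,
    litellm_params={},
    headers={},
)
# expected: {"text": "Hello", "max_length": 64, "num_return_sequences": 2, "stop_sequences": ["###"]}
```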


@ -0,0 +1,15 @@
from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
class NLPCloudError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)


@ -11,8 +11,10 @@ API calling is done using the OpenAI SDK with an api_base
import types import types
from typing import Optional, Union from typing import Optional, Union
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
class NvidiaNimConfig:
class NvidiaNimConfig(OpenAIGPTConfig):
""" """
Reference: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer Reference: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer
@@ -42,21 +44,7 @@ class NvidiaNimConfig:
@classmethod
def get_config(cls):
-return {
+return super().get_config()
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> list: def get_supported_openai_params(self, model: str) -> list:
""" """
@ -132,7 +120,11 @@ class NvidiaNimConfig:
] ]
def map_openai_params(
-self, model: str, non_default_params: dict, optional_params: dict
+self,
+non_default_params: dict,
+optional_params: dict,
+model: str,
+drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():

View file

@@ -242,6 +242,7 @@ class OllamaConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
+litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,


@ -14,6 +14,7 @@ from pydantic import BaseModel
import litellm import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import StreamingChoices from litellm.types.utils import StreamingChoices
@ -30,7 +31,7 @@ class OllamaError(Exception):
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
-class OllamaChatConfig:
+class OllamaChatConfig(OpenAIGPTConfig):
"""
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
@@ -81,15 +82,10 @@ class OllamaChatConfig:
num_thread: Optional[int] = None
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
-temperature: Optional[float] = None
seed: Optional[int] = None
-stop: Optional[list] = (
-None  # stop is a list based on this - https://github.com/ollama/ollama/pull/442
-)
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
top_k: Optional[int] = None
-top_p: Optional[float] = None
system: Optional[str] = None
template: Optional[str] = None
@@ -120,26 +116,9 @@ class OllamaChatConfig:
@classmethod
def get_config(cls):
-return {
+return super().get_config()
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and k != "function_name" # special param for function calling
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
-def get_supported_openai_params(
-self,
-):
+def get_supported_openai_params(self, model: str):
return [
"max_tokens",
"max_completion_tokens",
@@ -156,8 +135,12 @@ class OllamaChatConfig:
]
def map_openai_params(
-self, model: str, non_default_params: dict, optional_params: dict
-):
+self,
+non_default_params: dict,
+optional_params: dict,
+model: str,
+drop_params: bool,
+) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["num_predict"] = value


@ -6,28 +6,14 @@ from typing import Any, Callable, Optional
import requests # type: ignore import requests # type: ignore
from litellm.llms.custom_httpx.http_handler import HTTPHandler, _get_httpx_client
from litellm.utils import EmbeddingResponse, ModelResponse, Usage from litellm.utils import EmbeddingResponse, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import OobaboogaError
from .transformation import OobaboogaConfig
oobabooga_config = OobaboogaConfig()
class OobaboogaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
def validate_environment(api_key):
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Token {api_key}"
return headers
def completion( def completion(
@ -40,12 +26,18 @@ def completion(
api_key, api_key,
logging_obj, logging_obj,
optional_params: dict, optional_params: dict,
litellm_params: dict,
custom_prompt_dict={}, custom_prompt_dict={},
litellm_params=None,
logger_fn=None, logger_fn=None,
default_max_tokens_to_sample=None, default_max_tokens_to_sample=None,
): ):
headers = validate_environment(api_key) headers = oobabooga_config.validate_environment(
api_key=api_key,
headers={},
model=model,
messages=messages,
optional_params=optional_params,
)
if "https" in model: if "https" in model:
completion_url = model completion_url = model
elif api_base: elif api_base:
@ -58,10 +50,13 @@ def completion(
model = model model = model
completion_url = completion_url + "/v1/chat/completions" completion_url = completion_url + "/v1/chat/completions"
data = { data = oobabooga_config.transform_request(
"messages": messages, model=model,
**optional_params, messages=messages,
} optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -70,8 +65,8 @@ def completion(
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## COMPLETION CALL ## COMPLETION CALL
client = _get_httpx_client()
response = requests.post( response = client.post(
completion_url, completion_url,
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
@ -80,44 +75,18 @@ def completion(
if "stream" in optional_params and optional_params["stream"] is True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING return oobabooga_config.transform_response(
logging_obj.post_call( model=model,
input=messages, raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key, api_key=api_key,
original_response=response.text, request_data=data,
additional_args={"complete_input_dict": data}, messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
) )
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception:
raise OobaboogaError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise OobaboogaError(
message=completion_response["error"],
status_code=response.status_code,
)
else:
try:
model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
except Exception:
raise OobaboogaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=completion_response["usage"]["prompt_tokens"],
completion_tokens=completion_response["usage"]["completion_tokens"],
total_tokens=completion_response["usage"]["total_tokens"],
)
setattr(model_response, "usage", usage)
return model_response
def embedding( def embedding(
@ -127,7 +96,7 @@ def embedding(
api_key: Optional[str], api_key: Optional[str],
api_base: Optional[str], api_base: Optional[str],
logging_obj: Any, logging_obj: Any,
optional_params=None, optional_params: dict,
encoding=None, encoding=None,
): ):
# Create completion URL # Create completion URL
@ -153,7 +122,13 @@ def embedding(
) )
# Send POST request # Send POST request
headers = validate_environment(api_key) headers = oobabooga_config.validate_environment(
api_key=api_key,
headers={},
model=model,
messages=[],
optional_params=optional_params,
)
response = requests.post(embeddings_url, headers=headers, json=data) response = requests.post(embeddings_url, headers=headers, json=data)
if not response.ok: if not response.ok:
raise OobaboogaError(message=response.text, status_code=response.status_code) raise OobaboogaError(message=response.text, status_code=response.status_code)


@ -0,0 +1,110 @@
import json
import time
import types
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import OobaboogaError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
class OobaboogaConfig(OpenAIGPTConfig):
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
return messages
def get_error_class(
self,
error_message: str,
status_code: int,
headers: Optional[Union[dict, httpx.Headers]] = None,
) -> BaseLLMException:
return OobaboogaError(
status_code=status_code, message=error_message, headers=headers
)
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
except Exception:
raise OobaboogaError(
message=raw_response.text, status_code=raw_response.status_code
)
if "error" in completion_response:
raise OobaboogaError(
message=completion_response["error"],
status_code=raw_response.status_code,
)
else:
try:
model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
except Exception as e:
raise OobaboogaError(
message=str(e),
status_code=raw_response.status_code,
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=completion_response["usage"]["prompt_tokens"],
completion_tokens=completion_response["usage"]["completion_tokens"],
total_tokens=completion_response["usage"]["total_tokens"],
)
setattr(model_response, "usage", usage)
return model_response
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key is not None:
headers["Authorization"] = f"Token {api_key}"
return headers


@ -0,0 +1,15 @@
from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
class OobaboogaError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)


@@ -197,7 +197,8 @@ class OpenAIGPTConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
-encoding: str,
+litellm_params: dict,
+encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:


@ -1,63 +1,3 @@
""" """
Handler file for calls to OpenAI's o1 family of models LLM Calling done in `openai/openai.py`
Written separately to handle faking streaming for o1 models.
""" """
import asyncio
from typing import Any, Callable, List, Optional, Union
from httpx._config import Timeout
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.openai.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
class OpenAIO1ChatCompletion(OpenAIChatCompletion):
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable[..., Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
# stream: Optional[bool] = optional_params.pop("stream", False)
response = super().completion(
model_response,
timeout,
optional_params,
logging_obj,
model,
messages,
print_verbose,
api_key,
api_base,
acompletion,
litellm_params,
logger_fn,
headers,
custom_prompt_dict,
client,
organization,
custom_llm_provider,
drop_params,
)
return response


@ -3,7 +3,7 @@ Common helpers / utils across al OpenAI endpoints
""" """
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
import httpx import httpx
import openai import openai
@ -18,7 +18,7 @@ class OpenAIError(BaseLLMException):
message: str, message: str,
request: Optional[httpx.Request] = None, request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None, response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None, headers: Optional[Union[dict, httpx.Headers]] = None,
): ):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message


@ -4,7 +4,17 @@ import os
import time import time
import traceback import traceback
import types import types
from typing import Any, Callable, Coroutine, Iterable, Literal, Optional, Union, cast from typing import (
Any,
Callable,
Coroutine,
Iterable,
List,
Literal,
Optional,
Union,
cast,
)
import httpx import httpx
import openai import openai
@ -18,6 +28,7 @@ import litellm
from litellm import LlmProviders from litellm import LlmProviders
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ProviderField from litellm.types.utils import ProviderField
@ -35,6 +46,7 @@ from litellm.utils import (
from ...types.llms.openai import * from ...types.llms.openai import *
from ..base import BaseLLM from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory from ..prompt_templates.factory import custom_prompt, prompt_factory
from .chat.gpt_transformation import OpenAIGPTConfig
from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@ -81,135 +93,7 @@ class MistralEmbeddingConfig:
return optional_params return optional_params
class DeepInfraConfig: class OpenAIConfig(BaseConfig):
"""
Reference: https://deepinfra.com/docs/advanced/openai_api
The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
tools: Optional[list] = None
tool_choice: Optional[Union[str, dict]] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
tools: Optional[list] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"frequency_penalty",
"function_call",
"functions",
"logit_bias",
"max_tokens",
"max_completion_tokens",
"n",
"presence_penalty",
"stop",
"temperature",
"top_p",
"response_format",
"tools",
"tool_choice",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params()
for param, value in non_default_params.items():
if (
param == "temperature"
and value == 0
and model == "mistralai/Mistral-7B-Instruct-v0.1"
): # this model does no support temperature == 0
value = 0.0001 # close to 0
if param == "tool_choice":
if (
value != "auto" and value != "none"
): # https://deepinfra.com/docs/advanced/function_calling
## UNSUPPORTED TOOL CHOICE VALUE
if litellm.drop_params is True or drop_params is True:
value = None
else:
raise litellm.utils.UnsupportedParamsError(
message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
value
),
status_code=400,
)
elif param == "max_completion_tokens":
optional_params["max_tokens"] = value
elif param in supported_openai_params:
if value is not None:
optional_params[param] = value
return optional_params
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
# deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
api_base = (
api_base
or get_secret_str("DEEPINFRA_API_BASE")
or "https://api.deepinfra.com/v1/openai"
)
dynamic_api_key = api_key or get_secret_str("DEEPINFRA_API_KEY")
return api_base, dynamic_api_key
class OpenAIConfig:
""" """
Reference: https://platform.openai.com/docs/api-reference/chat/create Reference: https://platform.openai.com/docs/api-reference/chat/create
@@ -273,25 +157,12 @@ class OpenAIConfig:
@classmethod
def get_config(cls):
-return {
+return super().get_config()
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> list: def get_supported_openai_params(self, model: str) -> list:
""" """
This function returns the list of supported openai parameters for a given OpenAI Model This function returns the list
of supported openai parameters for a given OpenAI Model
- If O1 model, returns O1 supported params - If O1 model, returns O1 supported params
- If gpt-audio model, returns gpt-audio supported params - If gpt-audio model, returns gpt-audio supported params
@ -319,6 +190,11 @@ class OpenAIConfig:
optional_params[param] = value optional_params[param] = value
return optional_params return optional_params
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
return messages
def map_openai_params( def map_openai_params(
self, self,
non_default_params: dict, non_default_params: dict,
@ -349,6 +225,55 @@ class OpenAIConfig:
drop_params=drop_params, drop_params=drop_params,
) )
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return OpenAIError(
status_code=status_code,
message=error_message,
headers=headers,
)
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
return {"model": model, "messages": messages, **optional_params}
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError(
"OpenAI handler does this transformation as it uses the OpenAI SDK."
)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
raise NotImplementedError(
"OpenAI handler does this validation as it uses the OpenAI SDK."
)
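A hedged usage sketch of the new `OpenAIConfig`: it now owns request construction only, while auth validation and response parsing remain with the OpenAI SDK, which is why those methods raise. The model name below is just an example.

```python
import litellm

config = litellm.OpenAIConfig()

data = config.transform_request(
    model="gpt-4o-mini",  # example model
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={"temperature": 0.1},
    litellm_params={},
    headers={},
)
# data == {"model": "gpt-4o-mini", "messages": [...], "temperature": 0.1}

try:
    config.validate_environment(
        headers={}, model="gpt-4o-mini", messages=[], optional_params={}
    )
except NotImplementedError:
    pass  # expected: the OpenAI SDK performs this validation itself
```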
class OpenAIChatCompletion(BaseLLM): class OpenAIChatCompletion(BaseLLM):
@ -483,6 +408,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response: ModelResponse, model_response: ModelResponse,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
optional_params: dict, optional_params: dict,
litellm_params: dict,
logging_obj: Any, logging_obj: Any,
model: Optional[str] = None, model: Optional[str] = None,
messages: Optional[list] = None, messages: Optional[list] = None,
@ -490,7 +416,6 @@ class OpenAIChatCompletion(BaseLLM):
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
acompletion: bool = False, acompletion: bool = False,
litellm_params=None,
logger_fn=None, logger_fn=None,
headers: Optional[dict] = None, headers: Optional[dict] = None,
custom_prompt_dict: dict = {}, custom_prompt_dict: dict = {},
@ -516,31 +441,26 @@ class OpenAIChatCompletion(BaseLLM):
if custom_llm_provider is not None and custom_llm_provider != "openai": if custom_llm_provider is not None and custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}" model_response.model = f"{custom_llm_provider}/{model}"
# process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
messages = prompt_factory( # type: ignore
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if custom_llm_provider == "perplexity" and messages is not None:
# check if messages.name is passed + supported, if not supported remove
messages = prompt_factory( # type: ignore
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if messages is not None and custom_llm_provider is not None:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
-messages = provider_config._transform_messages(messages)
+if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
+provider_config, OpenAIConfig
+):
+messages = provider_config._transform_messages(messages)
for _ in range(
2
):  # if call fails due to alternating messages, retry with reformatted message
-data = {"model": model, "messages": messages, **optional_params}
+data = OpenAIConfig().transform_request(
+model=model,
+messages=messages,
+optional_params=optional_params,
+litellm_params=litellm_params,
+headers=headers or {},
+)
try:
max_retries = data.pop("max_retries", 2)
@ -2430,7 +2350,7 @@ class OpenAIAssistantsAPI(BaseLLM):
""" """
Here's an example: Here's an example:
``` ```
from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData
# create thread # create thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} message: MessageData = {"role": "user", "content": "Hey, how's it going?"}


@ -26,6 +26,8 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.llms.databricks.streaming_utils import ModelResponseIterator from litellm.llms.databricks.streaming_utils import ModelResponseIterator
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.llms.openai.openai import OpenAIConfig
from litellm.types.utils import CustomStreamingDecoder, ModelResponse from litellm.types.utils import CustomStreamingDecoder, ModelResponse
from litellm.utils import ( from litellm.utils import (
Choices, Choices,
@ -205,6 +207,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
) )
response.raise_for_status() response.raise_for_status()
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
print(f"e.response.text: {e.response.text}")
raise OpenAILikeError( raise OpenAILikeError(
status_code=e.response.status_code, status_code=e.response.status_code,
message=e.response.text, message=e.response.text,
@ -212,6 +215,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
except httpx.TimeoutException: except httpx.TimeoutException:
raise OpenAILikeError(status_code=408, message="Timeout error occurred.") raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
except Exception as e: except Exception as e:
print(f"e: {e}")
raise OpenAILikeError(status_code=500, message=str(e)) raise OpenAILikeError(status_code=500, message=str(e))
return OpenAILikeChatConfig._transform_response( return OpenAILikeChatConfig._transform_response(
@ -280,7 +284,10 @@ class OpenAILikeChatHandler(OpenAILikeBase):
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
-messages = provider_config._transform_messages(messages)
+if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
+provider_config, OpenAIConfig
+):
+messages = provider_config._transform_messages(messages)
data = {
"model": model,


@ -75,6 +75,7 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
custom_llm_provider: str, custom_llm_provider: str,
base_model: Optional[str], base_model: Optional[str],
) -> ModelResponse: ) -> ModelResponse:
print(f"response: {response}")
response_json = response.json() response_json = response.json()
logging_obj.post_call( logging_obj.post_call(
input=messages, input=messages,
@ -99,3 +100,25 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
if base_model is not None: if base_model is not None:
returned_response._hidden_params["model"] = base_model returned_response._hidden_params["model"] = base_model
return returned_response return returned_response
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
replace_max_completion_tokens_with_max_tokens: bool = True,
) -> dict:
mapped_params = super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
if (
"max_completion_tokens" in non_default_params
and replace_max_completion_tokens_with_max_tokens
):
mapped_params["max_tokens"] = non_default_params[
"max_completion_tokens"
] # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens'
mapped_params.pop("max_completion_tokens", None)
return mapped_params
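A hedged usage sketch of the new override (assuming the top-level re-export introduced in this commit): for OpenAI-compatible providers, `max_completion_tokens` is rewritten to `max_tokens` before the request is sent. The model name is an arbitrary placeholder.

```python
import litellm

params = litellm.OpenAILikeChatConfig().map_openai_params(
    non_default_params={"max_completion_tokens": 128, "temperature": 0.3},
    optional_params={},
    model="my-openai-compatible-model",  # arbitrary placeholder
    drop_params=False,
)
# expected: {"temperature": 0.3, "max_tokens": 128}
```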


@ -1,41 +0,0 @@
from typing import List, Dict
import types
class OpenrouterConfig:
"""
Reference: https://openrouter.ai/docs#format
"""
# OpenRouter-only parameters
extra_body: Dict[str, List[str]] = {"transforms": []} # default transforms to []
def __init__(
self,
transforms: List[str] = [],
models: List[str] = [],
route: str = "",
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}


@ -0,0 +1,43 @@
"""
Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
Docs: https://openrouter.ai/docs/parameters
"""
from typing import Optional
from litellm import get_model_info, verbose_logger
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
class OpenrouterConfig(OpenAIGPTConfig):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
mapped_openai_params = super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
# OpenRouter-only parameters
extra_body = {}
transforms = non_default_params.pop("transforms", None)
models = non_default_params.pop("models", None)
route = non_default_params.pop("route", None)
if transforms is not None:
extra_body["transforms"] = transforms
if models is not None:
extra_body["models"] = models
if route is not None:
extra_body["route"] = route
mapped_openai_params["extra_body"] = (
extra_body # openai client supports `extra_body` param
)
return mapped_openai_params
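A hedged usage sketch: OpenRouter-only knobs (`transforms`, `models`, `route`) are tucked into `extra_body`, which the OpenAI client forwards verbatim. This assumes `OpenrouterConfig` stays re-exported at the litellm top level; the model id is just an example.

```python
import litellm

params = litellm.OpenrouterConfig().map_openai_params(
    non_default_params={
        "temperature": 0.7,
        "transforms": ["middle-out"],
        "route": "fallback",
    },
    optional_params={},
    model="openrouter/anthropic/claude-3.5-sonnet",  # example model id
    drop_params=False,
)
# expected: params["extra_body"] == {"transforms": ["middle-out"], "route": "fallback"}
```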


@ -4,7 +4,7 @@ Common utility functions used for translating messages across providers
import json import json
from copy import deepcopy from copy import deepcopy
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union, cast
import litellm import litellm
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
@ -53,6 +53,13 @@ def strip_name_from_messages(
return new_messages return new_messages
def strip_none_values_from_message(message: AllMessageValues) -> AllMessageValues:
"""
Strips None values from message
"""
return cast(AllMessageValues, {k: v for k, v in message.items() if v is not None})
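A small usage sketch of the new helper (import path as shown in this commit's Mistral transformation): dropping `None`-valued keys before the request is what prevents the `extra_forbidden` error mentioned above.

```python
from litellm.llms.prompt_templates.common_utils import strip_none_values_from_message

msg = {"role": "user", "content": "hi", "name": None}
print(strip_none_values_from_message(msg))  # type: ignore[arg-type]
# -> {'role': 'user', 'content': 'hi'}
```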
def convert_content_list_to_str(message: AllMessageValues) -> str: def convert_content_list_to_str(message: AllMessageValues) -> str:
""" """
- handles scenario where content is list and not string - handles scenario where content is list and not string


@@ -2856,7 +2856,7 @@ def prompt_factory(
else:
return gemini_text_image_pt(messages=messages)
elif custom_llm_provider == "mistral":
-return litellm.MistralConfig._transform_messages(messages=messages)
+return litellm.MistralConfig()._transform_messages(messages=messages)
elif custom_llm_provider == "bedrock":
if "amazon.titan-text" in model:
return amazon_titan_pt(messages=messages)


@ -1,609 +0,0 @@
import asyncio
import json
import os
import time
import types
from typing import Any, Callable, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
class ReplicateError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.replicate.com/v1/deployments"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class ReplicateConfig:
"""
Reference: https://replicate.com/meta/llama-2-70b-chat/api
- `prompt` (string): The prompt to send to the model.
- `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
- `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
- `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
- `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
- `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
- `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
- `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either 'end' or '<stop>'.
- `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
- `debug` (boolean): If set to `True`, it provides debugging output in logs.
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
"""
system_prompt: Optional[str] = None
max_new_tokens: Optional[int] = None
min_new_tokens: Optional[int] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
stop_sequences: Optional[str] = None
seed: Optional[int] = None
debug: Optional[bool] = None
def __init__(
self,
system_prompt: Optional[str] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
stop_sequences: Optional[str] = None,
seed: Optional[int] = None,
debug: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# Function to start a prediction and get the prediction URL
def start_prediction(
version_id, input_data, api_token, api_base, logging_obj, print_verbose
):
base_url = api_base
if "deployments" in version_id:
print_verbose("\nLiteLLM: Request to custom replicate deployment")
version_id = version_id.replace("deployments/", "")
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
print_verbose(f"Deployment base URL: {base_url}\n")
else: # assume it's a model
base_url = f"https://api.replicate.com/v1/models/{version_id}"
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
initial_prediction_data = {
"input": input_data,
}
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
initial_prediction_data["version"] = model_parts[1]
## LOGGING
logging_obj.pre_call(
input=input_data["prompt"],
api_key="",
additional_args={
"complete_input_dict": initial_prediction_data,
"headers": headers,
"api_base": base_url,
},
)
response = requests.post(
f"{base_url}/predictions", json=initial_prediction_data, headers=headers
)
if response.status_code == 201:
response_data = response.json()
return response_data.get("urls", {}).get("get")
else:
raise ReplicateError(
response.status_code, f"Failed to start prediction {response.text}"
)
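For context on the version handling in the (now removed) helpers above: a Replicate model string may carry a 64-character version hash after a colon, and only then is a `version` field added to the prediction payload. An illustrative re-creation of that check, not litellm source:

```python
def build_prediction_payload(version_id: str) -> dict:
    payload: dict = {"input": {"prompt": "hi"}}
    if ":" in version_id and len(version_id) > 64:
        model_parts = version_id.split(":")
        # only attach "version" when the suffix is a 64-char version hash
        if len(model_parts) > 1 and len(model_parts[1]) == 64:
            payload["version"] = model_parts[1]
    return payload


print(build_prediction_payload(
    "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
))
```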
async def async_start_prediction(
version_id,
input_data,
api_token,
api_base,
logging_obj,
print_verbose,
http_handler: AsyncHTTPHandler,
) -> str:
base_url = api_base
if "deployments" in version_id:
print_verbose("\nLiteLLM: Request to custom replicate deployment")
version_id = version_id.replace("deployments/", "")
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
print_verbose(f"Deployment base URL: {base_url}\n")
else: # assume it's a model
base_url = f"https://api.replicate.com/v1/models/{version_id}"
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
initial_prediction_data = {
"input": input_data,
}
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
initial_prediction_data["version"] = model_parts[1]
## LOGGING
logging_obj.pre_call(
input=input_data["prompt"],
api_key="",
additional_args={
"complete_input_dict": initial_prediction_data,
"headers": headers,
"api_base": base_url,
},
)
response = await http_handler.post(
url="{}/predictions".format(base_url),
data=json.dumps(initial_prediction_data),
headers=headers,
)
if response.status_code == 201:
response_data = response.json()
return response_data.get("urls", {}).get("get")
else:
raise ReplicateError(
response.status_code, f"Failed to start prediction {response.text}"
)
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
status = ""
logs = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
print_verbose(f"replicate: polling endpoint: {prediction_url}")
time.sleep(0.5)
response = requests.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
if "output" in response_data:
output_string = "".join(response_data["output"])
print_verbose(f"Non-streamed output:{output_string}")
status = response_data.get("status", None)
logs = response_data.get("logs", "")
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose("Replicate: Failed to fetch prediction status and output.")
return output_string, logs
async def async_handle_prediction_response(
prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
) -> Tuple[str, Any]:
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
status = ""
logs = ""
while status not in ["succeeded", "failed", "canceled"]:
print_verbose(f"replicate: polling endpoint: {prediction_url}")
await asyncio.sleep(0.5) # prevent replicate rate limit errors
response = await http_handler.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
if "output" in response_data:
output_string = "".join(response_data["output"])
print_verbose(f"Non-streamed output:{output_string}")
status = response_data.get("status", None)
logs = response_data.get("logs", "")
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose("Replicate: Failed to fetch prediction status and output.")
return output_string, logs
# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
previous_output = ""
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
status = ""
while status not in ["succeeded", "failed", "canceled"]:
time.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = requests.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
status = response_data["status"]
if "output" in response_data:
try:
output_string = "".join(response_data["output"])
except Exception:
raise ReplicateError(
status_code=422,
message="Unable to parse response. Got={}".format(
response_data["output"]
),
)
new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data["status"]
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400, message=f"Error: {replicate_error}"
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Function to handle prediction response (streaming)
async def async_handle_prediction_response_streaming(
prediction_url, api_token, print_verbose
):
http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
previous_output = ""
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
status = ""
while status not in ["succeeded", "failed", "canceled"]:
await asyncio.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = await http_handler.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
status = response_data["status"]
if "output" in response_data:
try:
output_string = "".join(response_data["output"])
except Exception:
raise ReplicateError(
status_code=422,
message="Unable to parse response. Got={}".format(
response_data["output"]
),
)
new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data["status"]
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400, message=f"Error: {replicate_error}"
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Function to extract version ID from model string
def model_to_version_id(model):
if ":" in model:
split_model = model.split(":")
return split_model[1]
return model
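# Illustrative behaviour (values taken from the comments above):
# model_to_version_id("meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3")
#   -> "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
# model_to_version_id("deployments/my-org/my-deployment")  # hypothetical deployment name
#   -> "deployments/my-org/my-deployment"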
def process_response(
model_response: ModelResponse,
result: str,
model: str,
encoding: Any,
prompt: str,
) -> ModelResponse:
if len(result) == 0: # edge case, where result from replicate is empty
result = " "
## Building RESPONSE OBJECT
if len(result) >= 1:
model_response.choices[0].message.content = result # type: ignore
# Calculate usage
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", ""),
disallowed_special=(),
)
)
model_response.model = "replicate/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
# Main function for prediction completion
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
logging_obj,
api_key,
encoding,
custom_prompt_dict={},
litellm_params=None,
logger_fn=None,
acompletion=None,
) -> Union[ModelResponse, CustomStreamWrapper]:
# Start a prediction and get the prediction URL
version_id = model_to_version_id(model)
## Load Config
config = litellm.ReplicateConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
system_prompt = None
if optional_params is not None and "supports_system_prompt" in optional_params:
supports_sys_prompt = optional_params.pop("supports_system_prompt")
else:
supports_sys_prompt = False
if supports_sys_prompt:
for i in range(len(messages)):
if messages[i]["role"] == "system":
first_sys_message = messages.pop(i)
system_prompt = first_sys_message["content"]
break
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
if prompt is None or not isinstance(prompt, str):
raise ReplicateError(
status_code=400,
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
)
# If system prompt is supported, and a system prompt is provided, use it
if system_prompt is not None:
input_data = {
"prompt": prompt,
"system_prompt": system_prompt,
**optional_params,
}
# Otherwise, use the prompt as is
else:
input_data = {"prompt": prompt, **optional_params}
if acompletion is not None and acompletion is True:
return async_completion(
model_response=model_response,
model=model,
prompt=prompt,
encoding=encoding,
optional_params=optional_params,
version_id=version_id,
input_data=input_data,
api_key=api_key,
api_base=api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
) # type: ignore
## COMPLETION CALL
## Replicate Completion calls have 2 steps
## Step1: Start Prediction: gets a prediction url
## Step2: Poll prediction url for response
## Step2 is handled both with and without streaming
model_response.created = int(
time.time()
) # for pricing this must remain right before calling api
prediction_url = start_prediction(
version_id,
input_data,
api_key,
api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
)
print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] is True:
print_verbose("streaming request")
_response = handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
else:
result, logs = handle_prediction_response(
prediction_url, api_key, print_verbose
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=result,
additional_args={
"complete_input_dict": input_data,
"logs": logs,
"api_base": prediction_url,
},
)
print_verbose(f"raw model_response: {result}")
return process_response(
model_response=model_response,
result=result,
model=model,
encoding=encoding,
prompt=prompt,
)
async def async_completion(
model_response: ModelResponse,
model: str,
prompt: str,
encoding,
optional_params: dict,
version_id,
input_data,
api_key,
api_base,
logging_obj,
print_verbose,
) -> Union[ModelResponse, CustomStreamWrapper]:
http_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.REPLICATE,
)
prediction_url = await async_start_prediction(
version_id,
input_data,
api_key,
api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
http_handler=http_handler,
)
if "stream" in optional_params and optional_params["stream"] is True:
_response = async_handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
result, logs = await async_handle_prediction_response(
prediction_url, api_key, print_verbose, http_handler=http_handler
)
return process_response(
model_response=model_response,
result=result,
model=model,
encoding=encoding,
prompt=prompt,
)
# # Example usage:
# response = completion(
# api_key="",
# messages=[{"content": "good morning"}],
# model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
# model_response=ModelResponse(),
# print_verbose=print,
# logging_obj=print, # stub logging_obj
# optional_params={"stream": False}
# )
# print(response)

View file

@ -0,0 +1,285 @@
import asyncio
import json
import os
import time
import types
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import ReplicateError
from .transformation import ReplicateConfig
replicate_config = ReplicateConfig()
# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(
prediction_url, api_token, print_verbose, headers: dict, http_client: HTTPHandler
):
previous_output = ""
output_string = ""
status = ""
while status not in ["succeeded", "failed", "canceled"]:
time.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = http_client.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
status = response_data["status"]
if "output" in response_data:
try:
output_string = "".join(response_data["output"])
except Exception:
raise ReplicateError(
status_code=422,
message="Unable to parse response. Got={}".format(
response_data["output"]
),
headers=response.headers,
)
new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data["status"]
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}",
headers=response.headers,
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Function to handle prediction response (streaming)
async def async_handle_prediction_response_streaming(
prediction_url,
api_token,
print_verbose,
headers: dict,
http_client: AsyncHTTPHandler,
):
previous_output = ""
output_string = ""
status = ""
while status not in ["succeeded", "failed", "canceled"]:
await asyncio.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = await http_client.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
status = response_data["status"]
if "output" in response_data:
try:
output_string = "".join(response_data["output"])
except Exception:
raise ReplicateError(
status_code=422,
message="Unable to parse response. Got={}".format(
response_data["output"]
),
headers=response.headers,
)
new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data["status"]
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}",
headers=response.headers,
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Main function for prediction completion
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
litellm_params: dict,
logging_obj,
api_key,
encoding,
custom_prompt_dict={},
logger_fn=None,
acompletion=None,
headers={},
) -> Union[ModelResponse, CustomStreamWrapper]:
headers = replicate_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
# Start a prediction and get the prediction URL
version_id = replicate_config.model_to_version_id(model)
input_data = replicate_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
if acompletion is not None and acompletion is True:
return async_completion(
model_response=model_response,
model=model,
encoding=encoding,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
version_id=version_id,
input_data=input_data,
api_key=api_key,
api_base=api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
headers=headers,
) # type: ignore
## COMPLETION CALL
model_response.created = int(
time.time()
) # for pricing this must remain right before calling api
prediction_url = replicate_config.get_complete_url(api_base, model)
## COMPLETION CALL
httpx_client = _get_httpx_client(
params={"timeout": 600.0},
)
response = httpx_client.post(
url=prediction_url,
headers=headers,
data=json.dumps(input_data),
)
prediction_url = replicate_config.get_prediction_url(response)
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] is True:
print_verbose("streaming request")
_response = handle_prediction_response_streaming(
prediction_url,
api_key,
print_verbose,
headers=headers,
http_client=httpx_client,
)
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
else:
for _ in range(litellm.DEFAULT_MAX_RETRIES):
time.sleep(
1
) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing"
response = httpx_client.get(url=prediction_url, headers=headers)
return litellm.ReplicateConfig().transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=input_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
raise ReplicateError(
status_code=500,
message="No response received from Replicate API after max retries",
headers=None,
)
async def async_completion(
model_response: ModelResponse,
model: str,
messages: List[AllMessageValues],
encoding,
optional_params: dict,
litellm_params: dict,
version_id,
input_data,
api_key,
api_base,
logging_obj,
print_verbose,
headers: dict,
) -> Union[ModelResponse, CustomStreamWrapper]:
prediction_url = replicate_config.get_complete_url(api_base=api_base, model=model)
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.REPLICATE,
params={"timeout": 600.0},
)
response = await async_handler.post(
url=prediction_url, headers=headers, data=json.dumps(input_data)
)
prediction_url = replicate_config.get_prediction_url(response)
if "stream" in optional_params and optional_params["stream"] is True:
_response = async_handle_prediction_response_streaming(
prediction_url,
api_key,
print_verbose,
headers=headers,
http_client=async_handler,
)
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
for _ in range(litellm.DEFAULT_MAX_RETRIES):
await asyncio.sleep(
1
) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing"
response = await async_handler.get(url=prediction_url, headers=headers)
return litellm.ReplicateConfig().transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=input_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
# Add a fallback return if no response is received after max retries
raise ReplicateError(
status_code=500,
message="No response received from Replicate API after max retries",
headers=None,
)

View file

@ -0,0 +1,312 @@
import types
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import ReplicateError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
class ReplicateConfig(BaseConfig):
"""
Reference: https://replicate.com/meta/llama-2-70b-chat/api
- `prompt` (string): The prompt to send to the model.
- `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
- `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
- `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
- `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
- `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
- `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
- `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either '<end>' or '<stop>'.
- `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
- `debug` (boolean): If set to `True`, it provides debugging output in logs.
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
"""
system_prompt: Optional[str] = None
max_new_tokens: Optional[int] = None
min_new_tokens: Optional[int] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
stop_sequences: Optional[str] = None
seed: Optional[int] = None
debug: Optional[bool] = None
def __init__(
self,
system_prompt: Optional[str] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
stop_sequences: Optional[str] = None,
seed: Optional[int] = None,
debug: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"seed",
"tools",
"tool_choice",
"functions",
"function_call",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "max_tokens":
if "vicuna" in model or "flan" in model:
optional_params["max_length"] = value
elif "meta/codellama-13b" in model:
optional_params["max_tokens"] = value
else:
optional_params["max_new_tokens"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop_sequences"] = value
return optional_params
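# Illustrative mapping (hypothetical call, mirrors the branches above): for a llama-style
# model, OpenAI params are renamed to their Replicate equivalents, e.g.
# ReplicateConfig().map_openai_params(
#     non_default_params={"max_tokens": 256, "stop": ["\n\n"]},
#     optional_params={},
#     model="meta/llama-2-70b-chat",
#     drop_params=False,
# )
#   -> {"max_new_tokens": 256, "stop_sequences": ["\n\n"]}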
# Function to extract version ID from model string
def model_to_version_id(self, model: str) -> str:
if ":" in model:
split_model = model.split(":")
return split_model[1]
return model
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
return messages
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return ReplicateError(
status_code=status_code, message=error_message, headers=headers
)
def get_complete_url(self, api_base: str, model: str) -> str:
version_id = self.model_to_version_id(model)
base_url = api_base
if "deployments" in version_id:
version_id = version_id.replace("deployments/", "")
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
else: # assume it's a model
base_url = f"https://api.replicate.com/v1/models/{version_id}"
base_url = f"{base_url}/predictions"
return base_url
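# Illustrative URLs (hypothetical model names; note api_base is overridden in both branches above):
# ReplicateConfig().get_complete_url(api_base="https://api.replicate.com/v1", model="meta/llama-2-70b-chat")
#   -> "https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions"
# ReplicateConfig().get_complete_url(api_base="https://api.replicate.com/v1", model="deployments/my-org/my-deployment")
#   -> "https://api.replicate.com/v1/deployments/my-org/my-deployment/predictions"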
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
config = litellm.ReplicateConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
system_prompt = None
if optional_params is not None and "supports_system_prompt" in optional_params:
supports_sys_prompt = optional_params.pop("supports_system_prompt")
else:
supports_sys_prompt = False
if supports_sys_prompt:
for i in range(len(messages)):
if messages[i]["role"] == "system":
first_sys_message = messages.pop(i)
system_prompt = convert_content_list_to_str(first_sys_message)
break
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
if prompt is None or not isinstance(prompt, str):
raise ReplicateError(
status_code=400,
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
headers={},
)
# If system prompt is supported, and a system prompt is provided, use it
if system_prompt is not None:
input_data = {
"prompt": prompt,
"system_prompt": system_prompt,
**optional_params,
}
# Otherwise, use the prompt as is
else:
input_data = {"prompt": prompt, **optional_params}
version_id = self.model_to_version_id(model)
request_data: dict = {"input": input_data}
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if the model name carries a 64-character version hash - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
request_data["version"] = model_parts[1]
return request_data
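# Illustrative request body (hypothetical messages/params; the exact prompt string depends
# on prompt_factory):
# ReplicateConfig().transform_request(
#     model="meta/llama-2-70b-chat",
#     messages=[{"role": "user", "content": "hello"}],
#     optional_params={"max_new_tokens": 64},
#     litellm_params={},
#     headers={},
# )
#   -> {"input": {"prompt": "...", "max_new_tokens": 64}}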
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
raw_response_json = raw_response.json()
if raw_response_json.get("status") != "succeeded":
raise ReplicateError(
status_code=422,
message="LiteLLM Error - prediction not succeeded - {}".format(
raw_response_json
),
headers=raw_response.headers,
)
outputs = raw_response_json.get("output", [])
response_str = "".join(outputs)
if len(response_str) == 0: # edge case, where result from replicate is empty
response_str = " "
## Building RESPONSE OBJECT
if len(response_str) >= 1:
model_response.choices[0].message.content = response_str # type: ignore
# Calculate usage
prompt_tokens = token_counter(model=model, messages=messages)
completion_tokens = token_counter(
model=model,
text=response_str,
count_response_tokens=True,
)
model_response.model = "replicate/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def get_prediction_url(self, response: httpx.Response) -> str:
"""
response json: {
...,
"urls":{"cancel":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4/cancel","get":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4","stream":"https://stream-b.svc.rno2.c.replicate.net/v1/streams/eot4gbydowuin4snhncydwxt57dfwgsc3w3snycx5nid7oef7jga"}
}
"""
response_json = response.json()
prediction_url = response_json.get("urls", {}).get("get")
if prediction_url is None:
raise ReplicateError(
status_code=400,
message="LiteLLM Error - prediction url is None - {}".format(
response_json
),
headers=response.headers,
)
return prediction_url
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
headers = {
"Authorization": f"Token {api_key}",
"Content-Type": "application/json",
}
return headers
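# End-to-end sketch (assumptions: a Replicate API key is configured, e.g. via the
# REPLICATE_API_KEY environment variable, and the model below is available to your account;
# illustrative only, not part of the config):
#
# import litellm
#
# response = litellm.completion(
#     model="replicate/meta/llama-2-70b-chat",
#     messages=[{"role": "user", "content": "hello from litellm"}],
#     max_tokens=64,  # mapped to max_new_tokens by ReplicateConfig.map_openai_params above
# )
# print(response.choices[0].message.content)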

View file

@ -0,0 +1,15 @@
from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
class ReplicateError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]],
):
super().__init__(status_code=status_code, message=message, headers=headers)

View file

@ -363,6 +363,7 @@ class SagemakerLLM(BaseAWSLLM):
messages=messages,
optional_params=optional_params,
encoding=encoding,
litellm_params=litellm_params,
)
async def make_async_call(
@ -562,6 +563,7 @@ class SagemakerLLM(BaseAWSLLM):
messages=messages,
optional_params=optional_params,
encoding=encoding,
litellm_params=litellm_params,
)
def embedding(

View file

@ -202,6 +202,7 @@ class SagemakerConfig(BaseConfig):
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,

View file

@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
import types
from typing import Optional
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
class SambanovaConfig:
class SambanovaConfig(OpenAIGPTConfig):
"""
Reference: https://community.sambanova.ai/t/create-chat-completion-api/
@ -18,9 +20,7 @@ class SambanovaConfig:
max_tokens: Optional[int] = None
response_format: Optional[dict] = None
seed: Optional[int] = None
stop: Optional[str] = None
stream: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[int] = None
tool_choice: Optional[str] = None
tools: Optional[list] = None
@ -46,21 +46,7 @@ class SambanovaConfig:
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
"""
@ -80,12 +66,3 @@ class SambanovaConfig:
"tools",
"user",
]
def map_openai_params(
self, model: str, non_default_params: dict, optional_params: dict
) -> dict:
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params

View file

@ -22,6 +22,7 @@ from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.utils import (
Choices,
@ -91,19 +92,17 @@ async def make_call(
return completion_stream
class MistralTextCompletionConfig:
class MistralTextCompletionConfig(OpenAITextCompletionConfig):
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
"""
suffix: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
stream: Optional[bool] = None
random_seed: Optional[int] = None
stop: Optional[str] = None
def __init__(
self,
@ -123,23 +122,9 @@ class MistralTextCompletionConfig:
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str):
return [
"suffix",
"temperature",
@ -151,7 +136,13 @@ class MistralTextCompletionConfig:
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "suffix":
optional_params["suffix"] = value

View file

@ -1,22 +1,20 @@
from typing import List, Literal, Tuple
from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[Dict, httpx.Headers]] = None,
):
super().__init__(message=message, status_code=status_code, headers=headers)
def get_supports_system_message(

View file

@ -299,11 +299,13 @@ def _transform_request_body(
try:
if custom_llm_provider == "gemini":
content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
messages=messages
)
else:
content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
content = litellm.VertexGeminiConfig()._transform_messages(
messages=messages
)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
@ -460,15 +462,3 @@ def _transform_system_message(
return SystemInstructions(parts=system_content_blocks), messages
return None, messages
def set_headers(auth_header: Optional[str], extra_headers: Optional[dict]) -> dict:
headers = {
"Content-Type": "application/json",
}
if auth_header is not None:
headers["Authorization"] = f"Bearer {auth_header}"
if extra_headers is not None:
headers.update(extra_headers)
return headers

View file

@ -20,6 +20,7 @@ from typing import (
Optional,
Tuple,
Union,
cast,
)
import httpx # type: ignore
@ -30,6 +31,7 @@ import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -86,10 +88,16 @@ from .transformation import (
_gemini_convert_messages_with_history,
_process_gemini_image,
async_transform_request_body,
set_headers,
sync_transform_request_body,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
class VertexAIConfig:
"""
@ -277,7 +285,7 @@ class VertexAIConfig:
]
class VertexGeminiConfig:
class VertexGeminiConfig(BaseConfig):
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
@ -338,23 +346,9 @@ class VertexGeminiConfig:
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
@ -473,12 +467,11 @@ class VertexGeminiConfig:
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
non_default_params: dict,
optional_params: dict,
drop_params: bool,
):
) -> Dict:
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
@ -751,38 +744,38 @@ class VertexGeminiConfig:
return model_response
def _transform_response(
self,
model: str,
response: httpx.Response,
model_response: ModelResponse,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
litellm_params: dict,
api_key: str,
data: Union[dict, str, RequestBody],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response.text,
original_response=raw_response.text,
additional_args={"complete_input_dict": data},
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
completion_response = GenerateContentResponseBody(**raw_response.json()) # type: ignore
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
raw_response.text, str(e)
),
status_code=422,
headers=raw_response.headers,
)
## GET MODEL ##
@ -915,14 +908,53 @@ class VertexGeminiConfig:
completion_response, str(e)
),
status_code=422,
headers=raw_response.headers,
)
return model_response
@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[ContentType]:
return _gemini_convert_messages_with_history(messages=messages)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
return VertexAIError(
message=error_message, status_code=status_code, headers=headers
)
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
headers: Dict,
) -> Dict:
raise NotImplementedError(
"Vertex AI has a custom implementation of transform_request. Needs sync + async."
)
def validate_environment(
self,
headers: Optional[Dict],
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
) -> Dict:
default_headers = {
"Content-Type": "application/json",
}
if api_key is not None:
default_headers["Authorization"] = f"Bearer {api_key}"
if headers is not None:
default_headers.update(headers)
return default_headers
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
@ -978,23 +1010,9 @@ class GoogleAIStudioGeminiConfig(
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
@ -1012,22 +1030,27 @@ class GoogleAIStudioGeminiConfig(
def map_openai_params(
self,
model: str,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
):
) -> Dict:
# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
return super().map_openai_params(
model, non_default_params, optional_params, drop_params
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=drop_params,
)
@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
@ -1075,9 +1098,14 @@ async def make_call(
raise VertexAIError(
status_code=e.response.status_code,
message=VertexGeminiConfig().translate_exception_str(exception_string),
headers=e.response.headers,
)
if response.status_code != 200:
raise VertexAIError(status_code=response.status_code, message=response.text)
raise VertexAIError(
status_code=response.status_code,
message=response.text,
headers=response.headers,
)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
@ -1111,7 +1139,11 @@ def make_sync_call(
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise VertexAIError(status_code=response.status_code, message=response.read())
raise VertexAIError(
status_code=response.status_code,
message=str(response.read()),
headers=response.headers,
)
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
@ -1182,7 +1214,13 @@ class VertexLLM(VertexBase):
should_use_v1beta1_features=should_use_v1beta1_features,
)
headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
headers = VertexGeminiConfig().validate_environment(
api_key=auth_header,
headers=extra_headers,
model=model,
messages=messages,
optional_params=optional_params,
)
## LOGGING
logging_obj.pre_call(
@ -1263,7 +1301,13 @@ class VertexLLM(VertexBase):
should_use_v1beta1_features=should_use_v1beta1_features,
)
headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
headers = VertexGeminiConfig().validate_environment(
api_key=auth_header,
headers=extra_headers,
model=model,
messages=messages,
optional_params=optional_params,
)
request_body = await async_transform_request_body(**data) # type: ignore
_async_client_params = {}
@ -1287,23 +1331,32 @@ class VertexLLM(VertexBase):
)
try:
response = await client.post(api_base, headers=headers, json=request_body) # type: ignore
response = await client.post(
api_base, headers=headers, json=cast(dict, request_body)
) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=err.response.text)
raise VertexAIError(
status_code=error_code,
message=err.response.text,
headers=err.response.headers,
)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
raise VertexAIError(
status_code=408,
message="Timeout error occurred.",
headers=None,
)
return VertexGeminiConfig()._transform_response(
return VertexGeminiConfig().transform_response(
model=model,
response=response,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key="",
data=request_body,
request_data=cast(dict, request_body),
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
@ -1421,7 +1474,13 @@ class VertexLLM(VertexBase):
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
headers = VertexGeminiConfig().validate_environment(
api_key=auth_header,
headers=extra_headers,
model=model,
messages=messages,
optional_params=optional_params,
)
## TRANSFORMATION ##
data = sync_transform_request_body(**transform_request_params)
@ -1479,21 +1538,28 @@ class VertexLLM(VertexBase):
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=err.response.text)
raise VertexAIError(
status_code=error_code,
message=err.response.text,
headers=err.response.headers,
)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
raise VertexAIError(
status_code=408,
message="Timeout error occurred.",
headers=None,
)
return VertexGeminiConfig()._transform_response(
return VertexGeminiConfig().transform_response(
model=model,
response=response,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
api_key="",
data=data, # type: ignore
request_data=data, # type: ignore
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)

View file

@ -2,9 +2,10 @@ import types
from typing import Literal, Optional, Union
import litellm
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
class VolcEngineConfig:
class VolcEngineConfig(OpenAILikeChatConfig):
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
@ -38,21 +39,7 @@ class VolcEngineConfig:
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
return [
@ -77,14 +64,3 @@ class VolcEngineConfig:
"max_retries",
"extra_headers",
] # works across all models
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param == "max_completion_tokens":
optional_params["max_tokens"] = value
elif param in supported_openai_params:
optional_params[param] = value
return optional_params

View file

@ -274,6 +274,7 @@ class IBMWatsonXAIConfig(BaseConfig):
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,

View file

@ -83,26 +83,13 @@ from litellm.utils import (
from ._logging import verbose_logger from ._logging import verbose_logger
from .caching.caching import disable_cache, enable_cache, update_cache from .caching.caching import disable_cache, enable_cache, update_cache
from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
from .llms import ( from .llms import aleph_alpha, baseten, maritalk, ollama_chat, petals
aleph_alpha,
baseten,
maritalk,
nlp_cloud,
ollama_chat,
oobabooga,
openrouter,
palm,
petals,
replicate,
)
from .llms.ai21 import completion as ai21
from .llms.anthropic.chat import AnthropicChatCompletion from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.azure.audio_transcriptions import AzureAudioTranscription from .llms.azure.audio_transcriptions import AzureAudioTranscription
from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.azure_ai.chat import AzureAIChatCompletion from .llms.azure.completion.handler import AzureTextCompletion
from .llms.azure_ai.embed import AzureAIEmbedding from .llms.azure_ai.embed import AzureAIEmbedding
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration from .llms.bedrock.image.image_handler import BedrockImageGeneration
@ -111,13 +98,16 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from .llms.custom_llm import CustomLLM, custom_chat_llm_router from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat.handler import DatabricksChatCompletion from .llms.databricks.chat.handler import DatabricksChatCompletion
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
from .llms.deprecated_providers import palm
from .llms.groq.chat.handler import GroqChatCompletion from .llms.groq.chat.handler import GroqChatCompletion
from .llms.huggingface_restapi import Huggingface from .llms.huggingface.chat.handler import Huggingface
from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
from .llms.oobabooga.chat import oobabooga
from .llms.ollama.completion import handler as ollama from .llms.ollama.completion import handler as ollama
from .llms.openai.transcriptions.handler import OpenAIAudioTranscription from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
from .llms.openai.chat.o1_handler import OpenAIO1ChatCompletion
from .llms.openai.completion.handler import OpenAITextCompletion from .llms.openai.completion.handler import OpenAITextCompletion
from .llms.openai.openai import OpenAIChatCompletion from .llms.openai.openai import OpenAIChatCompletion
from .llms.openai_like.chat.handler import OpenAILikeChatHandler
from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
from .llms.predibase import PredibaseChatCompletion from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.common_utils import get_completion_messages from .llms.prompt_templates.common_utils import get_completion_messages
@ -131,6 +121,7 @@ from .llms.prompt_templates.factory import (
) )
from .llms.sagemaker.chat.handler import SagemakerChatHandler from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.replicate.chat.handler import completion as replicate_chat_completion
from .llms.text_completion_codestral import CodestralTextCompletion from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion from .llms.triton import TritonChatCompletion
@ -159,7 +150,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import ( from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
VertexAIModelGardenModels, VertexAIModelGardenModels,
) )
from .llms.vllm.completion import handler from .llms.vllm.completion import handler as vllm_handler
from .llms.watsonx.chat.handler import WatsonXChatHandler from .llms.watsonx.chat.handler import WatsonXChatHandler
from .llms.watsonx.completion.handler import IBMWatsonXAI from .llms.watsonx.completion.handler import IBMWatsonXAI
from .types.llms.openai import ( from .types.llms.openai import (
@ -196,12 +187,10 @@ from litellm.utils import (
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
openai_chat_completions = OpenAIChatCompletion() openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion() openai_text_completions = OpenAITextCompletion()
openai_o1_chat_completions = OpenAIO1ChatCompletion()
openai_audio_transcriptions = OpenAIAudioTranscription() openai_audio_transcriptions = OpenAIAudioTranscription()
databricks_chat_completions = DatabricksChatCompletion() databricks_chat_completions = DatabricksChatCompletion()
groq_chat_completions = GroqChatCompletion() groq_chat_completions = GroqChatCompletion()
together_ai_text_completions = TogetherAITextCompletion() together_ai_text_completions = TogetherAITextCompletion()
azure_ai_chat_completions = AzureAIChatCompletion()
azure_ai_embedding = AzureAIEmbedding() azure_ai_embedding = AzureAIEmbedding()
anthropic_chat_completions = AnthropicChatCompletion() anthropic_chat_completions = AnthropicChatCompletion()
azure_chat_completions = AzureChatCompletion() azure_chat_completions = AzureChatCompletion()
@ -228,6 +217,7 @@ watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM() sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler() watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler() openai_like_embedding = OpenAILikeEmbeddingHandler()
openai_like_chat_completion = OpenAILikeChatHandler()
databricks_embedding = DatabricksEmbeddingHandler() databricks_embedding = DatabricksEmbeddingHandler()
base_llm_http_handler = BaseLLMHTTPHandler() base_llm_http_handler = BaseLLMHTTPHandler()
sagemaker_chat_completion = SagemakerChatHandler() sagemaker_chat_completion = SagemakerChatHandler()
@ -449,6 +439,7 @@ async def acompletion(
or custom_llm_provider == "cerebras" or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova" or custom_llm_provider == "sambanova"
or custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21_chat"
or custom_llm_provider == "ai21"
or custom_llm_provider == "volcengine" or custom_llm_provider == "volcengine"
or custom_llm_provider == "codestral" or custom_llm_provider == "codestral"
or custom_llm_provider == "text-completion-codestral" or custom_llm_provider == "text-completion-codestral"
@ -1316,7 +1307,7 @@ def completion( # type: ignore # noqa: PLR0915
## COMPLETION CALL ## COMPLETION CALL
try: try:
response = azure_ai_chat_completions.completion( response = openai_chat_completions.completion(
model=model, model=model,
messages=messages, messages=messages,
headers=headers, headers=headers,
@ -1513,9 +1504,7 @@ def completion( # type: ignore # noqa: PLR0915
or custom_llm_provider == "nvidia_nim" or custom_llm_provider == "nvidia_nim"
or custom_llm_provider == "cerebras" or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova" or custom_llm_provider == "sambanova"
or custom_llm_provider == "ai21_chat"
or custom_llm_provider == "volcengine" or custom_llm_provider == "volcengine"
or custom_llm_provider == "codestral"
or custom_llm_provider == "deepseek" or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral" or custom_llm_provider == "mistral"
@ -1562,46 +1551,25 @@ def completion( # type: ignore # noqa: PLR0915
## COMPLETION CALL ## COMPLETION CALL
try: try:
if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model): response = openai_chat_completions.completion(
response = openai_o1_chat_completions.completion( model=model,
model=model, messages=messages,
messages=messages, headers=headers,
headers=headers, model_response=model_response,
model_response=model_response, print_verbose=print_verbose,
print_verbose=print_verbose, api_key=api_key,
api_key=api_key, api_base=api_base,
api_base=api_base, acompletion=acompletion,
acompletion=acompletion, logging_obj=logging,
logging_obj=logging, optional_params=optional_params,
optional_params=optional_params, litellm_params=litellm_params,
litellm_params=litellm_params, logger_fn=logger_fn,
logger_fn=logger_fn, timeout=timeout, # type: ignore
timeout=timeout, # type: ignore custom_prompt_dict=custom_prompt_dict,
custom_prompt_dict=custom_prompt_dict, client=client, # pass AsyncOpenAI, OpenAI client
client=client, # pass AsyncOpenAI, OpenAI client organization=organization,
organization=organization, custom_llm_provider=custom_llm_provider,
custom_llm_provider=custom_llm_provider, )
)
else:
response = openai_chat_completions.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
)
except Exception as e: except Exception as e:
## LOGGING - log the original exception returned ## LOGGING - log the original exception returned
logging.post_call( logging.post_call(
@ -1627,7 +1595,6 @@ def completion( # type: ignore # noqa: PLR0915
or model in litellm.replicate_models or model in litellm.replicate_models
): ):
# Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
replicate_key = None
replicate_key = ( replicate_key = (
api_key api_key
or litellm.replicate_key or litellm.replicate_key
@ -1645,7 +1612,7 @@ def completion( # type: ignore # noqa: PLR0915
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = replicate.completion( # type: ignore model_response = replicate_chat_completion( # type: ignore
model=model, model=model,
messages=messages, messages=messages,
api_base=api_base, api_base=api_base,
@ -1659,6 +1626,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging, logging_obj=logging,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
acompletion=acompletion, acompletion=acompletion,
headers=headers,
) )
if optional_params.get("stream", False) is True: if optional_params.get("stream", False) is True:
@ -1806,7 +1774,7 @@ def completion( # type: ignore # noqa: PLR0915
or "https://api.nlpcloud.io/v1/gpu/" or "https://api.nlpcloud.io/v1/gpu/"
) )
response = nlp_cloud.completion( response = nlp_cloud_chat_completion(
model=model, model=model,
messages=messages, messages=messages,
api_base=api_base, api_base=api_base,
@ -1969,10 +1937,10 @@ def completion( # type: ignore # noqa: PLR0915
api_base api_base
or litellm.api_base or litellm.api_base
or get_secret("MARITALK_API_BASE") or get_secret("MARITALK_API_BASE")
or "https://chat.maritaca.ai/api/chat/inference" or "https://chat.maritaca.ai/api"
) )
model_response = maritalk.completion( model_response = openai_like_chat_completion.completion(
model=model, model=model,
messages=messages, messages=messages,
api_base=api_base, api_base=api_base,
@ -1984,17 +1952,10 @@ def completion( # type: ignore # noqa: PLR0915
encoding=encoding, encoding=encoding,
api_key=maritalk_key, api_key=maritalk_key,
logging_obj=logging, logging_obj=logging,
custom_llm_provider="maritalk",
custom_prompt_dict=custom_prompt_dict,
) )
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="maritalk",
logging_obj=logging,
)
return response
response = model_response response = model_response
elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface":
custom_llm_provider = "huggingface" custom_llm_provider = "huggingface"
@ -2012,7 +1973,7 @@ def completion( # type: ignore # noqa: PLR0915
model=model, model=model,
messages=messages, messages=messages,
api_base=api_base, # type: ignore api_base=api_base, # type: ignore
headers=hf_headers, headers=hf_headers or {},
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
@ -2024,6 +1985,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging, logging_obj=logging,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
timeout=timeout, # type: ignore timeout=timeout, # type: ignore
client=client,
) )
if ( if (
"stream" in optional_params "stream" in optional_params
@ -2146,7 +2108,7 @@ def completion( # type: ignore # noqa: PLR0915
headers = openrouter_headers headers = openrouter_headers
## Load Config ## Load Config
config = openrouter.OpenrouterConfig.get_config() config = litellm.OpenrouterConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k == "extra_body": if k == "extra_body":
# we use openai 'extra_body' to pass openrouter specific params - transforms, route, models # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models
@ -2190,30 +2152,9 @@ def completion( # type: ignore # noqa: PLR0915
""" """
pass pass
elif custom_llm_provider == "palm": elif custom_llm_provider == "palm":
palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key raise ValueError(
"Palm was decommisioned on October 2024. Please use the `gemini/` route for Gemini Google AI Studio Models. Announcement: https://ai.google.dev/palm_docs/palm?hl=en"
# palm does not support streaming as yet :(
model_response = palm.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=palm_api_key,
logging_obj=logging,
) )
# fake palm streaming
if "stream" in optional_params and optional_params["stream"] is True:
# fake streaming for palm
resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper(
resp_string, model, custom_llm_provider="palm", logging_obj=logging
)
return response
response = model_response
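Since the `palm/` branch above now raises immediately, callers migrate to the `gemini/` route named in the error message. A sketch under the assumption that a Google AI Studio key is supplied via `GEMINI_API_KEY` (the env var is not shown in this hunk):

```python
# Migration sketch: the palm/ branch now raises, so requests move to the
# gemini/ route referenced in the error message above. GEMINI_API_KEY is assumed.
import os

from litellm import completion

os.environ["GEMINI_API_KEY"] = ""  # assumed env var for Google AI Studio

response = completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```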
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini": elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
vertex_ai_project = ( vertex_ai_project = (
optional_params.pop("vertex_project", None) optional_params.pop("vertex_project", None)
@ -2475,51 +2416,9 @@ def completion( # type: ignore # noqa: PLR0915
): ):
return _model_response return _model_response
response = _model_response response = _model_response
elif custom_llm_provider == "ai21":
custom_llm_provider = "ai21"
ai21_key = (
api_key
or litellm.ai21_key
or os.environ.get("AI21_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret("AI21_API_BASE")
or "https://api.ai21.com/studio/v1/"
)
model_response = ai21.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=ai21_key,
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="ai21",
logging_obj=logging,
)
return response
## RESPONSE OBJECT
response = model_response
elif custom_llm_provider == "sagemaker_chat": elif custom_llm_provider == "sagemaker_chat":
# boto3 reads keys from .env # boto3 reads keys from .env
response = sagemaker_chat_completion.completion( model_response = sagemaker_chat_completion.completion(
model=model, model=model,
messages=messages, messages=messages,
model_response=model_response, model_response=model_response,
@ -2531,9 +2430,13 @@ def completion( # type: ignore # noqa: PLR0915
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
headers=headers or {},
) )
elif custom_llm_provider == "sagemaker":
## RESPONSE OBJECT
response = model_response
elif (
custom_llm_provider == "sagemaker"
):
# boto3 reads keys from .env # boto3 reads keys from .env
model_response = sagemaker_llm.completion( model_response = sagemaker_llm.completion(
model=model, model=model,
@ -2691,7 +2594,7 @@ def completion( # type: ignore # noqa: PLR0915
response = response response = response
elif custom_llm_provider == "vllm": elif custom_llm_provider == "vllm":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = handler.completion( model_response = vllm_handler.completion(
model=model, model=model,
messages=messages, messages=messages,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
@ -3872,6 +3775,7 @@ async def atext_completion(
or custom_llm_provider == "cerebras" or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova" or custom_llm_provider == "sambanova"
or custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21_chat"
or custom_llm_provider == "ai21"
or custom_llm_provider == "volcengine" or custom_llm_provider == "volcengine"
or custom_llm_provider == "text-completion-codestral" or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek" or custom_llm_provider == "deepseek"

View file

@ -56,6 +56,7 @@ class AnthropicPassthroughLoggingHandler:
request_data={}, request_data={},
encoding=litellm.encoding, encoding=litellm.encoding,
json_mode=False, json_mode=False,
litellm_params={},
) )
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(

View file

@ -41,20 +41,19 @@ class VertexPassthroughLoggingHandler:
instance_of_vertex_llm = litellm.VertexGeminiConfig() instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: litellm.ModelResponse = ( litellm_model_response: litellm.ModelResponse = (
instance_of_vertex_llm._transform_response( instance_of_vertex_llm.transform_response(
model=model, model=model,
messages=[ messages=[
{"role": "user", "content": "no-message-pass-through-endpoint"} {"role": "user", "content": "no-message-pass-through-endpoint"}
], ],
response=httpx_response, raw_response=httpx_response,
model_response=litellm.ModelResponse(), model_response=litellm.ModelResponse(),
logging_obj=logging_obj, logging_obj=logging_obj,
optional_params={}, optional_params={},
litellm_params={}, litellm_params={},
api_key="", api_key="",
data={}, request_data={},
print_verbose=litellm.print_verbose, encoding=litellm.encoding,
encoding=None,
) )
) )
kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(

View file

@ -2923,22 +2923,16 @@ def get_optional_params( # noqa: PLR0915
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
if stream: optional_params = litellm.ReplicateConfig().map_openai_params(
optional_params["stream"] = stream non_default_params=non_default_params,
# return optional_params optional_params=optional_params,
if max_tokens is not None: model=model,
if "vicuna" in model or "flan" in model: drop_params=(
optional_params["max_length"] = max_tokens drop_params
elif "meta/codellama-13b" in model: if drop_params is not None and isinstance(drop_params, bool)
optional_params["max_tokens"] = max_tokens else False
else: ),
optional_params["max_new_tokens"] = max_tokens )
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stop is not None:
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "predibase": elif custom_llm_provider == "predibase":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -2954,7 +2948,14 @@ def get_optional_params( # noqa: PLR0915
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
optional_params = litellm.HuggingfaceConfig().map_openai_params( optional_params = litellm.HuggingfaceConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
## check if unsupported param passed in ## check if unsupported param passed in
@ -2973,53 +2974,6 @@ def get_optional_params( # noqa: PLR0915
else False else False
), ),
) )
elif custom_llm_provider == "ai21":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if stream:
optional_params["stream"] = stream
if n is not None:
optional_params["numResults"] = n
if max_tokens is not None:
optional_params["maxTokens"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["topP"] = top_p
if stop is not None:
optional_params["stopSequences"] = stop
if frequency_penalty is not None:
optional_params["frequencyPenalty"] = {"scale": frequency_penalty}
if presence_penalty is not None:
optional_params["presencePenalty"] = {"scale": presence_penalty}
elif (
custom_llm_provider == "palm"
): # https://developers.generativeai.google/tutorials/curl_quickstart
## check if unsupported param passed in
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stream:
optional_params["stream"] = stream
if n is not None:
optional_params["candidate_count"] = n
if stop is not None:
if isinstance(stop, str):
optional_params["stop_sequences"] = [stop]
elif isinstance(stop, list):
optional_params["stop_sequences"] = stop
if max_tokens is not None:
optional_params["max_output_tokens"] = max_tokens
elif custom_llm_provider == "vertex_ai" and ( elif custom_llm_provider == "vertex_ai" and (
model in litellm.vertex_chat_models model in litellm.vertex_chat_models
or model in litellm.vertex_code_chat_models or model in litellm.vertex_code_chat_models
@ -3120,12 +3074,25 @@ def get_optional_params( # noqa: PLR0915
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
if "codestral" in model: if "codestral" in model:
optional_params = litellm.MistralTextCompletionConfig().map_openai_params( optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
else: else:
optional_params = litellm.MistralConfig().map_openai_params( optional_params = litellm.MistralConfig().map_openai_params(
model=model,
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models: elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -3326,29 +3293,28 @@ def get_optional_params( # noqa: PLR0915
model=model, model=model,
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "nlp_cloud": elif custom_llm_provider == "nlp_cloud":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
optional_params = litellm.NLPCloudConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
if max_tokens is not None:
optional_params["max_length"] = max_tokens
if stream:
optional_params["stream"] = stream
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if presence_penalty is not None:
optional_params["presence_penalty"] = presence_penalty
if frequency_penalty is not None:
optional_params["frequency_penalty"] = frequency_penalty
if n is not None:
optional_params["num_return_sequences"] = n
if stop is not None:
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "petals": elif custom_llm_provider == "petals":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3435,7 +3401,14 @@ def get_optional_params( # noqa: PLR0915
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralConfig().map_openai_params( optional_params = litellm.MistralConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "text-completion-codestral": elif custom_llm_provider == "text-completion-codestral":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -3443,7 +3416,14 @@ def get_optional_params( # noqa: PLR0915
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralTextCompletionConfig().map_openai_params( optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "databricks": elif custom_llm_provider == "databricks":
@ -3470,6 +3450,11 @@ def get_optional_params( # noqa: PLR0915
model=model, model=model,
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "cerebras": elif custom_llm_provider == "cerebras":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -3480,6 +3465,11 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
model=model, model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "xai": elif custom_llm_provider == "xai":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -3491,7 +3481,7 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
) )
elif custom_llm_provider == "ai21_chat": elif custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
@ -3500,6 +3490,11 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
model=model, model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "fireworks_ai": elif custom_llm_provider == "fireworks_ai":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -3525,6 +3520,11 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
model=model, model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "hosted_vllm": elif custom_llm_provider == "hosted_vllm":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -3594,55 +3594,17 @@ def get_optional_params( # noqa: PLR0915
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
if functions is not None: optional_params = litellm.OpenrouterConfig().map_openai_params(
optional_params["functions"] = functions non_default_params=non_default_params,
if function_call is not None: optional_params=optional_params,
optional_params["function_call"] = function_call model=model,
if temperature is not None: drop_params=(
optional_params["temperature"] = temperature drop_params
if top_p is not None: if drop_params is not None and isinstance(drop_params, bool)
optional_params["top_p"] = top_p else False
if n is not None: ),
optional_params["n"] = n
if stream is not None:
optional_params["stream"] = stream
if stop is not None:
optional_params["stop"] = stop
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if presence_penalty is not None:
optional_params["presence_penalty"] = presence_penalty
if frequency_penalty is not None:
optional_params["frequency_penalty"] = frequency_penalty
if logit_bias is not None:
optional_params["logit_bias"] = logit_bias
if user is not None:
optional_params["user"] = user
if response_format is not None:
optional_params["response_format"] = response_format
if seed is not None:
optional_params["seed"] = seed
if tools is not None:
optional_params["tools"] = tools
if tool_choice is not None:
optional_params["tool_choice"] = tool_choice
if max_retries is not None:
optional_params["max_retries"] = max_retries
# OpenRouter-only parameters
extra_body = {}
transforms = passed_params.pop("transforms", None)
models = passed_params.pop("models", None)
route = passed_params.pop("route", None)
if transforms is not None:
extra_body["transforms"] = transforms
if models is not None:
extra_body["models"] = models
if route is not None:
extra_body["route"] = route
optional_params["extra_body"] = (
extra_body # openai client supports `extra_body` param
) )
elif custom_llm_provider == "watsonx": elif custom_llm_provider == "watsonx":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3727,7 +3689,11 @@ def get_optional_params( # noqa: PLR0915
optional_params=optional_params, optional_params=optional_params,
model=model, model=model,
api_version=api_version, # type: ignore api_version=api_version, # type: ignore
drop_params=drop_params, drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
else: # assume passing in params for text-completion openai else: # assume passing in params for text-completion openai
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
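The hunks above repeatedly thread an explicit `drop_params` flag into each provider config's `map_openai_params`. A small sketch of the resulting call shape, mirroring the HuggingFace mapping exercised by the tests later in this diff:

```python
# Sketch of the map_openai_params signature now used throughout
# get_optional_params: model, non_default_params, optional_params, drop_params.
import litellm

mapped = litellm.HuggingfaceConfig().map_openai_params(
    non_default_params={"max_completion_tokens": 10},
    optional_params={},
    model="llama3",
    drop_params=False,
)
print(mapped)  # expected per the tests below: {"max_new_tokens": 10}
```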
@ -6271,7 +6237,7 @@ from litellm.llms.base_llm.transformation import BaseConfig
class ProviderConfigManager: class ProviderConfigManager:
@staticmethod @staticmethod
def get_provider_chat_config( def get_provider_chat_config( # noqa: PLR0915
model: str, provider: litellm.LlmProviders model: str, provider: litellm.LlmProviders
) -> BaseConfig: ) -> BaseConfig:
""" """
@ -6333,6 +6299,60 @@ class ProviderConfigManager:
return litellm.LMStudioChatConfig() return litellm.LMStudioChatConfig()
elif litellm.LlmProviders.GALADRIEL == provider: elif litellm.LlmProviders.GALADRIEL == provider:
return litellm.GaladrielChatConfig() return litellm.GaladrielChatConfig()
elif litellm.LlmProviders.REPLICATE == provider:
return litellm.ReplicateConfig()
elif litellm.LlmProviders.HUGGINGFACE == provider:
return litellm.HuggingfaceConfig()
elif litellm.LlmProviders.TOGETHER_AI == provider:
return litellm.TogetherAIConfig()
elif litellm.LlmProviders.OPENROUTER == provider:
return litellm.OpenrouterConfig()
elif litellm.LlmProviders.GEMINI == provider:
return litellm.GoogleAIStudioGeminiConfig()
elif (
litellm.LlmProviders.AI21 == provider
or litellm.LlmProviders.AI21_CHAT == provider
):
return litellm.AI21ChatConfig()
elif litellm.LlmProviders.AZURE == provider:
return litellm.AzureOpenAIConfig()
elif litellm.LlmProviders.AZURE_AI == provider:
return litellm.AzureAIStudioConfig()
elif litellm.LlmProviders.AZURE_TEXT == provider:
return litellm.AzureOpenAITextConfig()
elif litellm.LlmProviders.HOSTED_VLLM == provider:
return litellm.HostedVLLMChatConfig()
elif litellm.LlmProviders.NLP_CLOUD == provider:
return litellm.NLPCloudConfig()
elif litellm.LlmProviders.OOBABOOGA == provider:
return litellm.OobaboogaConfig()
elif litellm.LlmProviders.OLLAMA_CHAT == provider:
return litellm.OllamaChatConfig()
elif litellm.LlmProviders.DEEPINFRA == provider:
return litellm.DeepInfraConfig()
elif litellm.LlmProviders.PERPLEXITY == provider:
return litellm.PerplexityChatConfig()
elif (
litellm.LlmProviders.MISTRAL == provider
or litellm.LlmProviders.CODESTRAL == provider
):
return litellm.MistralConfig()
elif litellm.LlmProviders.NVIDIA_NIM == provider:
return litellm.NvidiaNimConfig()
elif litellm.LlmProviders.CEREBRAS == provider:
return litellm.CerebrasConfig()
elif litellm.LlmProviders.VOLCENGINE == provider:
return litellm.VolcEngineConfig()
elif litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL == provider:
return litellm.MistralTextCompletionConfig()
elif litellm.LlmProviders.SAMBANOVA == provider:
return litellm.SambanovaConfig()
elif litellm.LlmProviders.MARITALK == provider:
return litellm.MaritalkConfig()
elif litellm.LlmProviders.CLOUDFLARE == provider:
return litellm.CloudflareChatConfig()
elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
return litellm.AnthropicTextConfig()
elif litellm.LlmProviders.VLLM == provider: elif litellm.LlmProviders.VLLM == provider:
return litellm.VLLMConfig() return litellm.VLLMConfig()
elif litellm.LlmProviders.OLLAMA == provider: elif litellm.LlmProviders.OLLAMA == provider:
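The new branches above make ProviderConfigManager the single lookup point for provider chat configs. A short usage sketch, reusing the call shape from the test helper later in this diff (the model string is illustrative):

```python
# Usage sketch: resolve a provider's chat config through the branches added
# above, then inspect its supported OpenAI params. Model string is illustrative.
import litellm
from litellm import LlmProviders
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="jamba-1.5-mini@001", provider=LlmProviders.AI21
)
print(type(config).__name__)  # expected: AI21ChatConfig, per the AI21 branch above
print(config.get_supported_openai_params("jamba-1.5-mini@001"))
```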

View file

@ -168,12 +168,17 @@ def test_all_model_configs():
drop_params=False, drop_params=False,
) == {"max_tokens": 10} ) == {"max_tokens": 10}
from litellm.llms.huggingface_restapi import HuggingfaceConfig from litellm.llms.huggingface.chat.handler import HuggingfaceConfig
assert "max_completion_tokens" in HuggingfaceConfig().get_supported_openai_params() assert "max_completion_tokens" in HuggingfaceConfig().get_supported_openai_params(
assert HuggingfaceConfig().map_openai_params({"max_completion_tokens": 10}, {}) == { model="llama3"
"max_new_tokens": 10 )
} assert HuggingfaceConfig().map_openai_params(
non_default_params={"max_completion_tokens": 10},
optional_params={},
model="llama3",
drop_params=False,
) == {"max_new_tokens": 10}
from litellm.llms.nvidia_nim.chat import NvidiaNimConfig from litellm.llms.nvidia_nim.chat import NvidiaNimConfig
@ -184,15 +189,19 @@ def test_all_model_configs():
model="llama3", model="llama3",
non_default_params={"max_completion_tokens": 10}, non_default_params={"max_completion_tokens": 10},
optional_params={}, optional_params={},
drop_params=False,
) == {"max_tokens": 10} ) == {"max_tokens": 10}
from litellm.llms.ollama_chat import OllamaChatConfig from litellm.llms.ollama_chat import OllamaChatConfig
assert "max_completion_tokens" in OllamaChatConfig().get_supported_openai_params() assert "max_completion_tokens" in OllamaChatConfig().get_supported_openai_params(
model="llama3"
)
assert OllamaChatConfig().map_openai_params( assert OllamaChatConfig().map_openai_params(
model="llama3", model="llama3",
non_default_params={"max_completion_tokens": 10}, non_default_params={"max_completion_tokens": 10},
optional_params={}, optional_params={},
drop_params=False,
) == {"num_predict": 10} ) == {"num_predict": 10}
from litellm.llms.predibase import PredibaseConfig from litellm.llms.predibase import PredibaseConfig
@ -207,11 +216,13 @@ def test_all_model_configs():
assert ( assert (
"max_completion_tokens" "max_completion_tokens"
in MistralTextCompletionConfig().get_supported_openai_params() in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
) )
assert MistralTextCompletionConfig().map_openai_params( assert MistralTextCompletionConfig().map_openai_params(
{"max_completion_tokens": 10}, model="llama3",
{}, non_default_params={"max_completion_tokens": 10},
optional_params={},
drop_params=False,
) == {"max_tokens": 10} ) == {"max_tokens": 10}
from litellm.llms.volcengine import VolcEngineConfig from litellm.llms.volcengine import VolcEngineConfig
@ -223,9 +234,10 @@ def test_all_model_configs():
model="llama3", model="llama3",
non_default_params={"max_completion_tokens": 10}, non_default_params={"max_completion_tokens": 10},
optional_params={}, optional_params={},
drop_params=False,
) == {"max_tokens": 10} ) == {"max_tokens": 10}
from litellm.llms.ai21.chat import AI21ChatConfig from litellm.llms.ai21.chat.transformation import AI21ChatConfig
assert "max_completion_tokens" in AI21ChatConfig().get_supported_openai_params( assert "max_completion_tokens" in AI21ChatConfig().get_supported_openai_params(
"jamba-1.5-mini@001" "jamba-1.5-mini@001"
@ -234,11 +246,14 @@ def test_all_model_configs():
model="jamba-1.5-mini@001", model="jamba-1.5-mini@001",
non_default_params={"max_completion_tokens": 10}, non_default_params={"max_completion_tokens": 10},
optional_params={}, optional_params={},
drop_params=False,
) == {"max_tokens": 10} ) == {"max_tokens": 10}
from litellm.llms.azure.chat.gpt_transformation import AzureOpenAIConfig from litellm.llms.azure.chat.gpt_transformation import AzureOpenAIConfig
assert "max_completion_tokens" in AzureOpenAIConfig().get_supported_openai_params() assert "max_completion_tokens" in AzureOpenAIConfig().get_supported_openai_params(
model="gpt-3.5-turbo"
)
assert AzureOpenAIConfig().map_openai_params( assert AzureOpenAIConfig().map_openai_params(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
non_default_params={"max_completion_tokens": 10}, non_default_params={"max_completion_tokens": 10},
@ -266,11 +281,13 @@ def test_all_model_configs():
assert ( assert (
"max_completion_tokens" "max_completion_tokens"
in MistralTextCompletionConfig().get_supported_openai_params() in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
) )
assert MistralTextCompletionConfig().map_openai_params( assert MistralTextCompletionConfig().map_openai_params(
model="llama3",
non_default_params={"max_completion_tokens": 10}, non_default_params={"max_completion_tokens": 10},
optional_params={}, optional_params={},
drop_params=False,
) == {"max_tokens": 10} ) == {"max_tokens": 10}
from litellm.llms.bedrock.common_utils import ( from litellm.llms.bedrock.common_utils import (
@ -341,7 +358,9 @@ def test_all_model_configs():
assert ( assert (
"max_completion_tokens" "max_completion_tokens"
in GoogleAIStudioGeminiConfig().get_supported_openai_params() in GoogleAIStudioGeminiConfig().get_supported_openai_params(
model="gemini-1.0-pro"
)
) )
assert GoogleAIStudioGeminiConfig().map_openai_params( assert GoogleAIStudioGeminiConfig().map_openai_params(
@ -351,7 +370,9 @@ def test_all_model_configs():
drop_params=False, drop_params=False,
) == {"max_output_tokens": 10} ) == {"max_output_tokens": 10}
assert "max_completion_tokens" in VertexGeminiConfig().get_supported_openai_params() assert "max_completion_tokens" in VertexGeminiConfig().get_supported_openai_params(
model="gemini-1.0-pro"
)
assert VertexGeminiConfig().map_openai_params( assert VertexGeminiConfig().map_openai_params(
model="gemini-1.0-pro", model="gemini-1.0-pro",

View file

@ -190,9 +190,10 @@ def test_databricks_optional_params():
custom_llm_provider="databricks", custom_llm_provider="databricks",
max_tokens=10, max_tokens=10,
temperature=0.2, temperature=0.2,
stream=True,
) )
print(f"optional_params: {optional_params}") print(f"optional_params: {optional_params}")
assert len(optional_params) == 2 assert len(optional_params) == 3
assert "user" not in optional_params assert "user" not in optional_params

View file

@ -449,8 +449,12 @@ def test_azure_tool_call_invoke_helper():
{"role": "assistant", "function_call": {"name": "get_weather"}}, {"role": "assistant", "function_call": {"name": "get_weather"}},
] ]
transformed_messages = litellm.AzureOpenAIConfig.transform_request( transformed_messages = litellm.AzureOpenAIConfig().transform_request(
model="gpt-4o", messages=messages, optional_params={} model="gpt-4o",
messages=messages,
optional_params={},
litellm_params={},
headers={},
) )
assert transformed_messages["messages"] == [ assert transformed_messages["messages"] == [

View file

@ -69,7 +69,7 @@ def test_batch_completions_models():
def test_batch_completion_models_all_responses(): def test_batch_completion_models_all_responses():
try: try:
responses = batch_completion_models_all_responses( responses = batch_completion_models_all_responses(
models=["j2-light", "claude-3-haiku-20240307"], models=["gemini/gemini-1.5-flash", "claude-3-haiku-20240307"],
messages=[{"role": "user", "content": "write a poem"}], messages=[{"role": "user", "content": "write a poem"}],
max_tokens=10, max_tokens=10,
) )

View file

@ -1606,30 +1606,33 @@ HF Tests we should pass
##################################################### #####################################################
##################################################### #####################################################
# Test util to sort models to TGI, conv, None # Test util to sort models to TGI, conv, None
from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig
def test_get_hf_task_for_model(): def test_get_hf_task_for_model():
model = "glaiveai/glaive-coder-7b" model = "glaiveai/glaive-coder-7b"
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}") print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference" assert model_type == "text-generation-inference"
model = "meta-llama/Llama-2-7b-hf" model = "meta-llama/Llama-2-7b-hf"
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}") print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference" assert model_type == "text-generation-inference"
model = "facebook/blenderbot-400M-distill" model = "facebook/blenderbot-400M-distill"
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}") print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational" assert model_type == "conversational"
model = "facebook/blenderbot-3B" model = "facebook/blenderbot-3B"
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}") print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational" assert model_type == "conversational"
# neither Conv or None # neither Conv or None
model = "roneneldan/TinyStories-3M" model = "roneneldan/TinyStories-3M"
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}") print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation" assert model_type == "text-generation"
@ -1717,14 +1720,17 @@ def tgi_mock_post(url, **kwargs):
def test_hf_test_completion_tgi(): def test_hf_test_completion_tgi():
litellm.set_verbose = True litellm.set_verbose = True
try: try:
client = HTTPHandler()
with patch("requests.post", side_effect=tgi_mock_post) as mock_client: with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
response = completion( response = completion(
model="huggingface/HuggingFaceH4/zephyr-7b-beta", model="huggingface/HuggingFaceH4/zephyr-7b-beta",
messages=[{"content": "Hello, how are you?", "role": "user"}], messages=[{"content": "Hello, how are you?", "role": "user"}],
max_tokens=10, max_tokens=10,
wait_for_model=True, wait_for_model=True,
client=client,
) )
mock_client.assert_called_once()
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)
assert "options" in mock_client.call_args.kwargs["data"] assert "options" in mock_client.call_args.kwargs["data"]
@ -1862,13 +1868,15 @@ def mock_post(url, **kwargs):
def test_hf_classifier_task(): def test_hf_classifier_task():
try: try:
with patch("requests.post", side_effect=mock_post): client = HTTPHandler()
with patch.object(client, "post", side_effect=mock_post):
litellm.set_verbose = True litellm.set_verbose = True
user_message = "I like you. I love you" user_message = "I like you. I love you"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
response = completion( response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier", model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages, messages=messages,
client=client,
) )
print(f"response: {response}") print(f"response: {response}")
assert isinstance(response, litellm.ModelResponse) assert isinstance(response, litellm.ModelResponse)
@ -3096,19 +3104,20 @@ async def test_completion_replicate_llama3(sync_mode):
response = completion( response = completion(
model=model_name, model=model_name,
messages=messages, messages=messages,
max_tokens=10,
) )
else: else:
response = await litellm.acompletion( response = await litellm.acompletion(
model=model_name, model=model_name,
messages=messages, messages=messages,
max_tokens=10,
) )
print(f"ASYNC REPLICATE RESPONSE - {response}") print(f"ASYNC REPLICATE RESPONSE - {response}")
print(response) print(f"REPLICATE RESPONSE - {response}")
# Add any assertions here to check the response # Add any assertions here to check the response
assert isinstance(response, litellm.ModelResponse) assert isinstance(response, litellm.ModelResponse)
assert len(response.choices[0].message.content.strip()) > 0
response_format_tests(response=response) response_format_tests(response=response)
except litellm.APIError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@ -3745,22 +3754,6 @@ def test_mistral_anyscale_stream():
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
#### Test A121 ###################
@pytest.mark.skip(reason="Local test")
def test_completion_ai21():
print("running ai21 j2light test")
litellm.set_verbose = True
model_name = "j2-light"
try:
response = completion(
model=model_name, messages=messages, max_tokens=100, temperature=0.8
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_ai21() # test_completion_ai21()
# test_completion_ai21() # test_completion_ai21()
## test deep infra ## test deep infra

View file

@ -165,10 +165,10 @@ def test_get_gpt3_tokens():
# test_get_gpt3_tokens() # test_get_gpt3_tokens()
def test_get_palm_tokens(): def test_get_gemini_tokens():
# # 🦄🦄🦄🦄🦄🦄🦄🦄 # # 🦄🦄🦄🦄🦄🦄🦄🦄
max_tokens = get_max_tokens("palm/chat-bison") max_tokens = get_max_tokens("gemini/gemini-1.5-flash")
assert max_tokens == 4096 assert max_tokens == 8192
print(max_tokens) print(max_tokens)

View file

@ -29,19 +29,6 @@ def logger_fn(user_model_dict):
pass pass
# completion with num retries + impact on exception mapping
def test_completion_with_num_retries():
try:
response = completion(
model="j2-ultra",
messages=[{"messages": "vibe", "bad": "message"}],
num_retries=2,
)
pytest.fail(f"Unmapped exception occurred")
except Exception as e:
pass
# test_completion_with_num_retries() # test_completion_with_num_retries()
def test_completion_with_0_num_retries(): def test_completion_with_0_num_retries():
try: try:

View file

@ -290,35 +290,46 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val
def test_provider_config_manager(): from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders from litellm.utils import ProviderConfigManager
from litellm.utils import ProviderConfigManager from litellm.llms.base_llm.transformation import BaseConfig
from litellm.llms.base_llm.transformation import BaseConfig
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
for provider in LITELLM_CHAT_PROVIDERS:
if provider == LlmProviders.TRITON or provider == LlmProviders.PREDIBASE:
continue
assert isinstance(
ProviderConfigManager.get_provider_chat_config(
model="gpt-3.5-turbo", provider=LlmProviders(provider)
),
BaseConfig,
), f"Provider {provider} is not a subclass of BaseConfig"
config = ProviderConfigManager.get_provider_chat_config( def _check_provider_config(config: BaseConfig, provider: LlmProviders):
model="gpt-3.5-turbo", provider=LlmProviders(provider) assert isinstance(
) config,
BaseConfig,
if ( ), f"Provider {provider} is not a subclass of BaseConfig. Got={config}"
provider != litellm.LlmProviders.OPENAI
and provider != litellm.LlmProviders.OPENAI_LIKE
and provider != litellm.LlmProviders.CUSTOM_OPENAI
):
assert (
config.__class__.__name__ != "OpenAIGPTConfig"
), f"Provider {provider} is an instance of OpenAIGPTConfig"
if (
provider != litellm.LlmProviders.OPENAI
and provider != litellm.LlmProviders.OPENAI_LIKE
and provider != litellm.LlmProviders.CUSTOM_OPENAI
):
assert ( assert (
"_abc_impl" not in config.get_config() config.__class__.__name__ != "OpenAIGPTConfig"
), f"Provider {provider} has _abc_impl" ), f"Provider {provider} is an instance of OpenAIGPTConfig"
assert "_abc_impl" not in config.get_config(), f"Provider {provider} has _abc_impl"
# def test_provider_config_manager():
# from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
# for provider in LITELLM_CHAT_PROVIDERS:
# if (
# provider == LlmProviders.VERTEX_AI
# or provider == LlmProviders.VERTEX_AI_BETA
# or provider == LlmProviders.BEDROCK
# or provider == LlmProviders.BASETEN
# or provider == LlmProviders.SAGEMAKER
# or provider == LlmProviders.SAGEMAKER_CHAT
# or provider == LlmProviders.VLLM
# or provider == LlmProviders.PETALS
# or provider == LlmProviders.OLLAMA
# ):
# continue
# config = ProviderConfigManager.get_provider_chat_config(
# model="gpt-3.5-turbo", provider=LlmProviders(provider)
# )
# _check_provider_config(config, provider)

View file

@ -522,6 +522,7 @@ async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
) )
@pytest.mark.flaky(retries=5, delay=3)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_gcs_logging_config_without_service_account(): async def test_get_gcs_logging_config_without_service_account():
""" """

View file

@ -167,51 +167,6 @@ def cohere_test_completion():
# cohere_test_completion() # cohere_test_completion()
# AI21
def ai21_test_completion():
litellm.AI21Config(maxTokens=10)
litellm.set_verbose = True
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
response_1 = litellm.completion(
model="j2-mid",
messages=[
{
"content": "Hello, how are you? Be as verbose as possible",
"role": "user",
}
],
max_tokens=100,
)
response_1_text = response_1.choices[0].message.content
print(f"response_1_text: {response_1_text}")
# USE CONFIG TOKENS
response_2 = litellm.completion(
model="j2-mid",
messages=[
{
"content": "Hello, how are you? Be as verbose as possible",
"role": "user",
}
],
)
response_2_text = response_2.choices[0].message.content
print(f"response_2_text: {response_2_text}")
assert len(response_2_text) < len(response_1_text)
response_3 = litellm.completion(
model="j2-light",
messages=[{"content": "Hello, how are you?", "role": "user"}],
n=2,
)
assert len(response_3.choices) > 1
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# ai21_test_completion() # ai21_test_completion()

View file

@ -47,6 +47,7 @@ def cleanup_redis():
print(f"Error cleaning up Redis: {str(e)}") print(f"Error cleaning up Redis: {str(e)}")
@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_provider_budgets_e2e_test(): async def test_provider_budgets_e2e_test():
""" """
@ -106,7 +107,7 @@ async def test_provider_budgets_e2e_test():
print("response.hidden_params", response._hidden_params) print("response.hidden_params", response._hidden_params)
await asyncio.sleep(0.5) await asyncio.sleep(1)
assert response._hidden_params.get("custom_llm_provider") == "azure" assert response._hidden_params.get("custom_llm_provider") == "azure"

View file

@ -1931,66 +1931,11 @@ async def test_completion_watsonx_stream():
# raise Exception("Empty response received") # raise Exception("Empty response received")
# except Exception: # except Exception:
# pytest.fail(f"error occurred: {traceback.format_exc()}") # pytest.fail(f"error occurred: {traceback.format_exc()}")
# test_maritalk_streaming()
# test on openai completion call
# # test on ai21 completion call
def ai21_completion_call():
try:
messages = [
{
"role": "system",
"content": "You are an all-knowing oracle",
},
{"role": "user", "content": "What is the meaning of the Universe?"},
]
response = completion(
model="j2-ultra", messages=messages, stream=True, max_tokens=500
)
print(f"response: {response}")
has_finished = False
complete_response = ""
start_time = time.time()
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
has_finished = finished
complete_response += chunk
if finished:
break
if has_finished is False:
raise Exception("finished reason missing from final chunk")
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except Exception:
pytest.fail(f"error occurred: {traceback.format_exc()}")
# ai21_completion_call() # ai21_completion_call()
def ai21_completion_call_bad_key():
try:
api_key = "bad-key"
response = completion(
model="j2-ultra", messages=messages, stream=True, api_key=api_key
)
print(f"response: {response}")
complete_response = ""
start_time = time.time()
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
break
complete_response += chunk
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except Exception:
pytest.fail(f"error occurred: {traceback.format_exc()}")
# ai21_completion_call_bad_key() # ai21_completion_call_bad_key()
@ -2418,34 +2363,6 @@ def test_completion_openai_with_functions():
#### Test Async streaming #### #### Test Async streaming ####
# # test on ai21 completion call
async def ai21_async_completion_call():
try:
response = completion(
model="j2-ultra", messages=messages, stream=True, logger_fn=logger_fn
)
print(f"response: {response}")
complete_response = ""
start_time = time.time()
# Change for loop to async for loop
idx = 0
async for chunk in response:
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
break
complete_response += chunk
idx += 1
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
except Exception:
print(f"error occurred: {traceback.format_exc()}")
pass
# asyncio.run(ai21_async_completion_call())
async def completion_call(): async def completion_call():
try: try:
response = completion( response = completion(

View file

@ -3934,6 +3934,7 @@ def test_completion_text_003_prompt_array():
##### hugging face tests ##### hugging face tests
@pytest.mark.skip(reason="local test")
def test_completion_hf_prompt_array(): def test_completion_hf_prompt_array():
try: try:
litellm.set_verbose = True litellm.set_verbose = True

View file

@ -437,8 +437,8 @@ def test_token_counter():
print(tokens) print(tokens)
assert tokens > 0 assert tokens > 0
tokens = token_counter(model="palm/chat-bison", messages=messages) tokens = token_counter(model="gemini/chat-bison", messages=messages)
print("palm/chat-bison") print("gemini/chat-bison")
print(tokens) print(tokens)
assert tokens > 0 assert tokens > 0
@ -465,7 +465,7 @@ def test_token_counter():
("azure/gpt-4-1106-preview", True), ("azure/gpt-4-1106-preview", True),
("groq/gemma-7b-it", True), ("groq/gemma-7b-it", True),
("anthropic.claude-instant-v1", False), ("anthropic.claude-instant-v1", False),
("palm/chat-bison", False), ("gemini/gemini-1.5-flash", True),
], ],
) )
def test_supports_function_calling(model, expected_bool): def test_supports_function_calling(model, expected_bool):