Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
Litellm merge pr (#7161)
* build: merge branch
* test: fix openai naming
* fix(main.py): fix openai renaming
* style: ignore function length for config factory
* fix(sagemaker/): fix routing logic
* fix: fix imports
* fix: fix override
This commit is contained in: parent d5aae81c6d, commit 350cfc36f7
88 changed files with 3617 additions and 4421 deletions
@@ -1,43 +0,0 @@
-# PaLM API - Google
-
-:::warning
-
-Warning: [The PaLM API is decommissioned by Google](https://ai.google.dev/palm_docs/deprecation). The PaLM API is scheduled to be decommissioned in October 2024. Please upgrade to the Gemini API or Vertex AI API.
-
-:::
-
-## Pre-requisites
-* `pip install -q google-generativeai`
-
-## Sample Usage
-```python
-from litellm import completion
-import os
-
-os.environ['PALM_API_KEY'] = ""
-response = completion(
-    model="palm/chat-bison",
-    messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
-)
-```
-
-## Sample Usage - Streaming
-```python
-from litellm import completion
-import os
-
-os.environ['PALM_API_KEY'] = ""
-response = completion(
-    model="palm/chat-bison",
-    messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
-    stream=True
-)
-
-for chunk in response:
-    print(chunk)
-```
-
-## Chat Models
-| Model Name | Function Call | Required OS Variables |
-|------------|---------------|-----------------------|
-| chat-bison | `completion('palm/chat-bison', messages)` | `os.environ['PALM_API_KEY']` |
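Since the removed doc told users to upgrade to the Gemini API, here is a minimal migration sketch (not part of this diff). The `gemini/gemini-pro` model name and the `GEMINI_API_KEY` variable are assumptions based on LiteLLM's Gemini provider docs and should be checked there.

```python
# Hypothetical PaLM -> Gemini migration via LiteLLM; model name and env var are assumptions.
import os
from litellm import completion

os.environ["GEMINI_API_KEY"] = ""  # Google AI Studio key

response = completion(
    model="gemini/gemini-pro",  # replaces "palm/chat-bison"
    messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
)
print(response.choices[0].message.content)
```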
@@ -190,11 +190,9 @@ const sidebars = {
         "providers/aleph_alpha",
         "providers/baseten",
         "providers/openrouter",
-        "providers/palm",
         "providers/sambanova",
         "providers/custom_llm_server",
         "providers/petals",
       ],
     },
     {
@@ -601,6 +601,7 @@ openai_compatible_providers: List = [
    "cerebras",
    "sambanova",
    "ai21_chat",
+    "ai21",
    "volcengine",
    "codestral",
    "deepseek",
@@ -853,7 +854,6 @@ class LlmProviders(str, Enum):
     OPENROUTER = "openrouter"
     VERTEX_AI = "vertex_ai"
     VERTEX_AI_BETA = "vertex_ai_beta"
-    PALM = "palm"
     GEMINI = "gemini"
     AI21 = "ai21"
     BASETEN = "baseten"
@@ -871,7 +871,6 @@ class LlmProviders(str, Enum):
     OLLAMA_CHAT = "ollama_chat"
     DEEPINFRA = "deepinfra"
     PERPLEXITY = "perplexity"
-    ANYSCALE = "anyscale"
     MISTRAL = "mistral"
     GROQ = "groq"
     NVIDIA_NIM = "nvidia_nim"
@@ -1057,10 +1056,15 @@ from .types.utils import ImageObject
 from .llms.custom_llm import CustomLLM
 from .llms.openai_like.chat.handler import OpenAILikeChatConfig
 from .llms.galadriel.chat.transformation import GaladrielChatConfig
-from .llms.huggingface_restapi import HuggingfaceConfig
-from .llms.empower.chat.transformation import EmpowerChatConfig
 from .llms.github.chat.transformation import GithubChatConfig
-from .llms.anthropic.chat.handler import AnthropicConfig
+from .llms.empower.chat.transformation import EmpowerChatConfig
+from .llms.huggingface.chat.transformation import (
+    HuggingfaceChatConfig as HuggingfaceConfig,
+)
+from .llms.oobabooga.chat.transformation import OobaboogaConfig
+from .llms.maritalk import MaritalkConfig
+from .llms.openrouter.chat.transformation import OpenrouterConfig
+from .llms.anthropic.chat.transformation import AnthropicConfig
 from .llms.anthropic.experimental_pass_through.transformation import (
     AnthropicExperimentalPassThroughConfig,
 )
@@ -1069,24 +1073,26 @@ from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.databricks.chat.transformation import DatabricksConfig
 from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase import PredibaseConfig
-from .llms.replicate import ReplicateConfig
+from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
-from .llms.cloudflare.chat.transformation import CloudflareChatConfig
-from .llms.ai21.completion import AI21Config
-from .llms.ai21.chat import AI21ChatConfig
+from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig
-from .llms.palm import PalmConfig
-from .llms.gemini import GeminiConfig
-from .llms.nlp_cloud import NLPCloudConfig
+from .llms.cloudflare.chat.transformation import CloudflareChatConfig
+from .llms.deprecated_providers.palm import (
+    PalmConfig,
+)  # here to prevent breaking changes
+from .llms.nlp_cloud.chat.handler import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,
     VertexAIConfig,
+    GoogleAIStudioGeminiConfig as GeminiConfig,
 )

 from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import (
     VertexAITextEmbeddingConfig,
 )
@@ -1107,7 +1113,6 @@ from .llms.ollama.completion.transformation import OllamaConfig
 from .llms.sagemaker.completion.transformation import SagemakerConfig
 from .llms.sagemaker.chat.transformation import SagemakerChatConfig
 from .llms.ollama_chat import OllamaChatConfig
-from .llms.maritalk import MaritTalkConfig
 from .llms.bedrock.chat.invoke_handler import (
     AmazonCohereChatConfig,
     AmazonConverseConfig,
@@ -1134,11 +1139,8 @@ from .llms.bedrock.embed.amazon_titan_v2_transformation import (
 )
 from .llms.cohere.chat.transformation import CohereChatConfig
 from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
-from .llms.openai.openai import (
-    OpenAIConfig,
-    MistralEmbeddingConfig,
-    DeepInfraConfig,
-)
+from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
+from .llms.deepinfra.chat.transformation import DeepInfraConfig
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
@@ -1167,7 +1169,7 @@ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()

 from .llms.cerebras.chat import CerebrasConfig
 from .llms.sambanova.chat import SambanovaConfig
-from .llms.ai21.chat import AI21ChatConfig
+from .llms.ai21.chat.transformation import AI21ChatConfig
 from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
 from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
     FireworksAIEmbeddingConfig,
@@ -1183,6 +1185,7 @@ from .llms.azure.azure import (
 )

 from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
+from .llms.azure.completion.transformation import AzureOpenAITextConfig
 from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
 from .llms.vllm.completion.transformation import VLLMConfig
 from .llms.deepseek.chat.transformation import DeepSeekChatConfig
@@ -3,54 +3,51 @@ DEFAULT_BATCH_SIZE = 512
 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
 DEFAULT_MAX_RETRIES = 2
 LITELLM_CHAT_PROVIDERS = [
-    # "openai",
-    # "openai_like",
-    # "xai",
-    # "custom_openai",
-    # "text-completion-openai",
-    # "cohere",
-    # "cohere_chat",
-    # "clarifai",
-    # "anthropic",
-    # "anthropic_text",
-    # "replicate",
-    # "huggingface",
-    # "together_ai",
-    # "openrouter",
-    # "vertex_ai",
-    # "vertex_ai_beta",
-    # "palm",
-    # "gemini",
-    # "ai21",
-    # "baseten",
-    # "azure",
-    # "azure_text",
-    # "azure_ai",
-    # "sagemaker",
-    # "sagemaker_chat",
-    # "bedrock",
+    "openai",
+    "openai_like",
+    "xai",
+    "custom_openai",
+    "text-completion-openai",
+    "cohere",
+    "cohere_chat",
+    "clarifai",
+    "anthropic",
+    "anthropic_text",
+    "replicate",
+    "huggingface",
+    "together_ai",
+    "openrouter",
+    "vertex_ai",
+    "vertex_ai_beta",
+    "gemini",
+    "ai21",
+    "baseten",
+    "azure",
+    "azure_text",
+    "azure_ai",
+    "sagemaker",
+    "sagemaker_chat",
+    "bedrock",
     "vllm",
-    # "nlp_cloud",
-    # "petals",
-    # "oobabooga",
+    "nlp_cloud",
+    "petals",
+    "oobabooga",
     "ollama",
-    # "ollama_chat",
-    # "deepinfra",
-    # "perplexity",
-    # "anyscale",
-    # "mistral",
-    # "groq",
-    # "nvidia_nim",
-    # "cerebras",
-    # "ai21_chat",
-    # "volcengine",
-    # "codestral",
-    # "text-completion-codestral",
-    # "deepseek",
-    # "sambanova",
-    # "maritalk",
-    # "voyage",
-    # "cloudflare",
+    "ollama_chat",
+    "deepinfra",
+    "perplexity",
+    "mistral",
+    "groq",
+    "nvidia_nim",
+    "cerebras",
+    "ai21_chat",
+    "volcengine",
+    "codestral",
+    "text-completion-codestral",
+    "deepseek",
+    "sambanova",
+    "maritalk",
+    "cloudflare",
     "fireworks_ai",
     "friendliai",
     "watsonx",
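The provider list above is now live data instead of a commented-out sketch. A quick membership check (assuming the list stays importable from `litellm/constants.py`, as the hunk's `DEFAULT_*` neighbors suggest):

```python
# Illustrative only: "palm", "anyscale" and "voyage" were dropped, most other
# entries were uncommented.
from litellm.constants import LITELLM_CHAT_PROVIDERS

print("ollama_chat" in LITELLM_CHAT_PROVIDERS)  # True: previously commented out
print("palm" in LITELLM_CHAT_PROVIDERS)         # False: removed with the PaLM provider
```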
@@ -285,9 +285,7 @@ def get_llm_provider(  # noqa: PLR0915
    ):
        custom_llm_provider = "vertex_ai"
    ## ai21
-    elif model in litellm.ai21_models:
-        custom_llm_provider = "ai21"
-    elif model in litellm.ai21_chat_models:
+    elif model in litellm.ai21_chat_models or model in litellm.ai21_models:
        custom_llm_provider = "ai21_chat"
        api_base = (
            api_base
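With this change, models listed in either `litellm.ai21_chat_models` or the legacy `litellm.ai21_models` resolve to the `ai21_chat` provider. An illustrative check; the Jamba model name is a placeholder, and `get_llm_provider` is assumed to return its usual `(model, provider, api_key, api_base)` tuple:

```python
import litellm

# Placeholder model name; any entry of litellm.ai21_chat_models behaves the same way.
_, provider, _, _ = litellm.get_llm_provider(model="jamba-1.5-mini")
print(provider)  # expected: "ai21_chat"
```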
@@ -31,7 +31,7 @@ def get_supported_openai_params(  # noqa: PLR0915
    elif custom_llm_provider == "ollama":
        return litellm.OllamaConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "ollama_chat":
-        return litellm.OllamaChatConfig().get_supported_openai_params()
+        return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "anthropic":
        return litellm.AnthropicConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "fireworks_ai":
@@ -50,7 +50,7 @@ def get_supported_openai_params(  # noqa: PLR0915
        return litellm.CerebrasConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "xai":
        return litellm.XAIChatConfig().get_supported_openai_params(model=model)
-    elif custom_llm_provider == "ai21_chat":
+    elif custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21":
        return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "volcengine":
        return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
@@ -97,79 +97,50 @@ def get_supported_openai_params(  # noqa: PLR0915
                model=model
            )
        else:
-            return litellm.AzureOpenAIConfig().get_supported_openai_params()
+            return litellm.AzureOpenAIConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "openrouter":
-        return [
-            "temperature",
-            "top_p",
-            "frequency_penalty",
-            "presence_penalty",
-            "repetition_penalty",
-            "seed",
-            "max_tokens",
-            "logit_bias",
-            "logprobs",
-            "top_logprobs",
-            "response_format",
-            "stop",
-            "tools",
-            "tool_choice",
-        ]
+        return litellm.OpenrouterConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
        # mistal and codestral api have the exact same params
        if request_type == "chat_completion":
-            return litellm.MistralConfig().get_supported_openai_params()
+            return litellm.MistralConfig().get_supported_openai_params(model=model)
        elif request_type == "embeddings":
            return litellm.MistralEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "text-completion-codestral":
-        return litellm.MistralTextCompletionConfig().get_supported_openai_params()
+        return litellm.MistralTextCompletionConfig().get_supported_openai_params(
+            model=model
+        )
+    elif custom_llm_provider == "sambanova":
+        return litellm.SambanovaConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "replicate":
-        return [
-            "stream",
-            "temperature",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "seed",
-            "tools",
-            "tool_choice",
-            "functions",
-            "function_call",
-        ]
+        return litellm.ReplicateConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "huggingface":
-        return litellm.HuggingfaceConfig().get_supported_openai_params()
+        return litellm.HuggingfaceConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "jina_ai":
        if request_type == "embeddings":
            return litellm.JinaAIEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "together_ai":
        return litellm.TogetherAIConfig().get_supported_openai_params(model=model)
-    elif custom_llm_provider == "ai21":
-        return [
-            "stream",
-            "n",
-            "temperature",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "frequency_penalty",
-            "presence_penalty",
-        ]
    elif custom_llm_provider == "databricks":
        if request_type == "chat_completion":
            return litellm.DatabricksConfig().get_supported_openai_params(model=model)
        elif request_type == "embeddings":
            return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
-        return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
+        return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params(
+            model=model
+        )
    elif custom_llm_provider == "vertex_ai":
        if request_type == "chat_completion":
            if model.startswith("meta/"):
                return litellm.VertexAILlama3Config().get_supported_openai_params()
            if model.startswith("mistral"):
-                return litellm.MistralConfig().get_supported_openai_params()
+                return litellm.MistralConfig().get_supported_openai_params(model=model)
            if model.startswith("codestral"):
                return (
-                    litellm.MistralTextCompletionConfig().get_supported_openai_params()
+                    litellm.MistralTextCompletionConfig().get_supported_openai_params(
+                        model=model
+                    )
                )
            if model.startswith("claude"):
                return litellm.VertexAIAnthropicConfig().get_supported_openai_params(
@@ -180,7 +151,7 @@ def get_supported_openai_params(  # noqa: PLR0915
            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "vertex_ai_beta":
        if request_type == "chat_completion":
-            return litellm.VertexGeminiConfig().get_supported_openai_params()
+            return litellm.VertexGeminiConfig().get_supported_openai_params(model=model)
        elif request_type == "embeddings":
            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "sagemaker":
@@ -199,20 +170,11 @@ def get_supported_openai_params(  # noqa: PLR0915
    elif custom_llm_provider == "cloudflare":
        return litellm.CloudflareChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "nlp_cloud":
-        return [
-            "max_tokens",
-            "stream",
-            "temperature",
-            "top_p",
-            "presence_penalty",
-            "frequency_penalty",
-            "n",
-            "stop",
-        ]
+        return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "petals":
        return ["max_tokens", "temperature", "top_p", "stream"]
    elif custom_llm_provider == "deepinfra":
-        return litellm.DeepInfraConfig().get_supported_openai_params()
+        return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "perplexity":
        return [
            "temperature",
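The net effect of these hunks: hard-coded parameter lists (OpenRouter, Replicate, AI21, NLP Cloud) are replaced by the provider Config classes, and most `get_supported_openai_params` calls now pass `model`. A hedged sketch of the public helper that sits on top of this logic; the model name is a placeholder:

```python
from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="openrouter/anthropic/claude-3.5-sonnet",
    custom_llm_provider="openrouter",
)
print(params)  # now sourced from litellm.OpenrouterConfig rather than a literal list
```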
@@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
 import types
 from typing import Optional, Union

+from ...openai_like.chat.transformation import OpenAILikeChatConfig
+

-class AI21ChatConfig:
+class AI21ChatConfig(OpenAILikeChatConfig):
     """
     Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters

@@ -19,8 +21,6 @@ class AI21ChatConfig:
     response_format: Optional[dict] = None
     documents: Optional[list] = None
     max_tokens: Optional[int] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
     stop: Optional[Union[str, list]] = None
     n: Optional[int] = None
     stream: Optional[bool] = None
@@ -49,21 +49,7 @@ class AI21ChatConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

     def get_supported_openai_params(self, model: str) -> list:
         """
@@ -77,22 +63,9 @@ class AI21ChatConfig:
             "max_tokens",
             "max_completion_tokens",
             "temperature",
-            "top_p",
             "stop",
             "n",
             "stream",
             "seed",
             "tool_choice",
-            "user",
         ]
-
-    def map_openai_params(
-        self, model: str, non_default_params: dict, optional_params: dict
-    ) -> dict:
-        supported_openai_params = self.get_supported_openai_params(model=model)
-        for param, value in non_default_params.items():
-            if param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            elif param in supported_openai_params:
-                optional_params[param] = value
-        return optional_params
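After this refactor, `AI21ChatConfig` gets `get_config` and the parameter mapping from `OpenAILikeChatConfig` instead of re-implementing them. A small sketch; the Jamba model name is a placeholder:

```python
import litellm

config = litellm.AI21ChatConfig()
# Supported params are still declared locally; mapping is inherited from
# OpenAILikeChatConfig rather than the deleted map_openai_params method.
print(config.get_supported_openai_params(model="jamba-1.5-mini"))
```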
@@ -1,221 +0,0 @@
-import json
-import os
-import time  # type: ignore
-import traceback
-import types
-from enum import Enum
-from typing import Callable, Optional
-
-import httpx
-import requests  # type: ignore
-
-import litellm
-from litellm.utils import Choices, Message, ModelResponse
-
-
-class AI21Error(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url="https://api.ai21.com/studio/v1/"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class AI21Config:
-    """
-    Reference: https://docs.ai21.com/reference/j2-complete-ref
-
-    The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters:
-
-    - `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful.
-
-    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
-
-    - `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated.
-
-    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
-
-    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
-
-    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
-
-    - `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position.
-
-    - `frequencyPenalty` (object): Placeholder for frequency penalty object.
-
-    - `presencePenalty` (object): Placeholder for presence penalty object.
-
-    - `countPenalty` (object): Placeholder for count penalty object.
-    """
-
-    numResults: Optional[int] = None
-    maxTokens: Optional[int] = None
-    minTokens: Optional[int] = None
-    temperature: Optional[float] = None
-    topP: Optional[float] = None
-    stopSequences: Optional[list] = None
-    topKReturn: Optional[int] = None
-    frequencePenalty: Optional[dict] = None
-    presencePenalty: Optional[dict] = None
-    countPenalty: Optional[dict] = None
-
-    def __init__(
-        self,
-        numResults: Optional[int] = None,
-        maxTokens: Optional[int] = None,
-        minTokens: Optional[int] = None,
-        temperature: Optional[float] = None,
-        topP: Optional[float] = None,
-        stopSequences: Optional[list] = None,
-        topKReturn: Optional[int] = None,
-        frequencePenalty: Optional[dict] = None,
-        presencePenalty: Optional[dict] = None,
-        countPenalty: Optional[dict] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-
-def validate_environment(api_key):
-    if api_key is None:
-        raise ValueError(
-            "Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
-        )
-    headers = {
-        "accept": "application/json",
-        "content-type": "application/json",
-        "Authorization": "Bearer " + api_key,
-    }
-    return headers
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-):
-    headers = validate_environment(api_key)
-    model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += f"{message['content']}"
-            else:
-                prompt += f"{message['content']}"
-        else:
-            prompt += f"{message['content']}"
-
-    ## Load Config
-    config = litellm.AI21Config.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    data = {
-        "prompt": prompt,
-        # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-        **optional_params,
-    }
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=prompt,
-        api_key=api_key,
-        additional_args={"complete_input_dict": data},
-    )
-    ## COMPLETION CALL
-    response = requests.post(
-        api_base + model + "/complete", headers=headers, data=json.dumps(data)
-    )
-    if response.status_code != 200:
-        raise AI21Error(status_code=response.status_code, message=response.text)
-    if "stream" in optional_params and optional_params["stream"] is True:
-        return response.iter_lines()
-    else:
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        ## RESPONSE OBJECT
-        completion_response = response.json()
-        try:
-            choices_list = []
-            for idx, item in enumerate(completion_response["completions"]):
-                if len(item["data"]["text"]) > 0:
-                    message_obj = Message(content=item["data"]["text"])
-                else:
-                    message_obj = Message(content=None)
-                choice_obj = Choices(
-                    finish_reason=item["finishReason"]["reason"],
-                    index=idx + 1,
-                    message=message_obj,
-                )
-                choices_list.append(choice_obj)
-            model_response.choices = choices_list  # type: ignore
-        except Exception:
-            raise AI21Error(
-                message=traceback.format_exc(), status_code=response.status_code
-            )
-
-        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content"))
-        )
-
-        model_response.created = int(time.time())
-        model_response.model = model
-        setattr(
-            model_response,
-            "usage",
-            litellm.Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens,
-            ),
-        )
-        return model_response
-
-
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
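With the requests-based J2 handler deleted, AI21 traffic goes through the OpenAI-compatible chat route (`AI21ChatConfig`). A hedged usage sketch; the model ID and the `AI21_API_KEY` variable name are assumptions to verify against the current provider docs:

```python
import os
from litellm import completion

os.environ["AI21_API_KEY"] = ""  # assumed env var name

response = completion(
    model="ai21/jamba-1.5-mini",  # placeholder model ID
    messages=[{"role": "user", "content": "say hi from LiteLLM"}],
)
print(response.choices[0].message.content)
```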
@@ -52,20 +52,6 @@ from ..common_utils import AnthropicError, process_anthropic_headers
 from .transformation import AnthropicConfig


-# makes headers for API call
-def validate_environment(
-    api_key,
-    user_headers,
-    model,
-    messages: List[AllMessageValues],
-    is_vertex_request: bool,
-    tools: Optional[List[AllAnthropicToolsValues]],
-    anthropic_version: Optional[str] = None,
-):
-
-    pass
-
-
 async def make_call(
     client: Optional[AsyncHTTPHandler],
     api_base: str,
@@ -239,7 +225,7 @@ class AnthropicChatCompletion(BaseLLM):
        data: dict,
        optional_params: dict,
        json_mode: bool,
-        litellm_params=None,
+        litellm_params: dict,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
@@ -283,6 +269,7 @@ class AnthropicChatCompletion(BaseLLM):
            request_data=data,
            messages=messages,
            optional_params=optional_params,
+            litellm_params=litellm_params,
            encoding=encoding,
            json_mode=json_mode,
        )
@@ -460,6 +447,7 @@ class AnthropicChatCompletion(BaseLLM):
            request_data=data,
            messages=messages,
            optional_params=optional_params,
+            litellm_params=litellm_params,
            encoding=encoding,
            json_mode=json_mode,
        )
@@ -567,6 +567,7 @@ class AnthropicConfig(BaseConfig):
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
+        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
@@ -715,11 +716,6 @@ class AnthropicConfig(BaseConfig):
            return litellm.Message(content=json_mode_content_str)
        return None

-    def _transform_messages(
-        self, messages: List[AllMessageValues]
-    ) -> List[AllMessageValues]:
-        return messages
-
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
@@ -180,6 +180,7 @@ class AnthropicTextConfig(BaseConfig):
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
+        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
@@ -35,38 +35,11 @@ from ...types.llms.openai import (
     RetrieveBatchRequest,
 )
 from ..base import BaseLLM
-from .common_utils import process_azure_headers
+from .common_utils import AzureOpenAIError, process_azure_headers

 azure_ad_cache = DualCache()


-class AzureOpenAIError(Exception):
-    def __init__(
-        self,
-        status_code,
-        message,
-        request: Optional[httpx.Request] = None,
-        response: Optional[httpx.Response] = None,
-        headers: Optional[httpx.Headers] = None,
-    ):
-        self.status_code = status_code
-        self.message = message
-        self.headers = headers
-        if request:
-            self.request = request
-        else:
-            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
-        if response:
-            self.response = response
-        else:
-            self.response = httpx.Response(
-                status_code=status_code, request=self.request
-            )
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
 class AzureOpenAIAssistantsAPIConfig:
     """
     Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
@@ -412,8 +385,12 @@ class AzureChatCompletion(BaseLLM):

                data = {"model": None, "messages": messages, **optional_params}
            else:
-                data = litellm.AzureOpenAIConfig.transform_request(
-                    model=model, messages=messages, optional_params=optional_params
+                data = litellm.AzureOpenAIConfig().transform_request(
+                    model=model,
+                    messages=messages,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    headers=headers or {},
                )

            if acompletion is True:
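`AzureOpenAIConfig.transform_request` is now an instance method with the `BaseConfig` signature, so it also receives `litellm_params` and `headers`. A minimal offline sketch of the new call shape, mirroring the call site above; the model name is a placeholder:

```python
import litellm

data = litellm.AzureOpenAIConfig().transform_request(
    model="gpt-4o",  # placeholder deployment/model name
    messages=[{"role": "user", "content": "hi"}],
    optional_params={"temperature": 0.2},
    litellm_params={},
    headers={},
)
print(data["messages"])  # messages converted via convert_to_azure_openai_messages
```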
@@ -1,7 +1,10 @@
 import types
-from typing import List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Type, Union
+
+from httpx._models import Headers, Response

 import litellm
+from litellm.llms.base_llm.transformation import BaseLLMException

 from ....exceptions import UnsupportedParamsError
 from ....types.llms.openai import (
@@ -11,10 +14,19 @@ from ....types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
 )
+from ...base_llm.transformation import BaseConfig
 from ...prompt_templates.factory import convert_to_azure_openai_messages
+from ..common_utils import AzureOpenAIError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+    LoggingClass = LiteLLMLoggingObj
+else:
+    LoggingClass = Any


-class AzureOpenAIConfig:
+class AzureOpenAIConfig(BaseConfig):
     """
     Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

@@ -61,23 +73,9 @@ class AzureOpenAIConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

-    def get_supported_openai_params(self):
+    def get_supported_openai_params(self, model: str) -> List[str]:
         return [
             "temperature",
             "n",
@@ -110,10 +108,10 @@ class AzureOpenAIConfig:
         non_default_params: dict,
         optional_params: dict,
         model: str,
-        api_version: str,  # Y-M-D-{optional}
-        drop_params,
+        drop_params: bool,
+        api_version: str = "",
     ) -> dict:
-        supported_openai_params = self.get_supported_openai_params()
+        supported_openai_params = self.get_supported_openai_params(model)

         api_version_times = api_version.split("-")
         api_version_year = api_version_times[0]
@@ -204,9 +202,13 @@ class AzureOpenAIConfig:

         return optional_params

-    @classmethod
     def transform_request(
-        cls, model: str, messages: List[AllMessageValues], optional_params: dict
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
     ) -> dict:
         messages = convert_to_azure_openai_messages(messages)
         return {
@@ -215,6 +217,24 @@ class AzureOpenAIConfig:
             **optional_params,
         }

+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: litellm.ModelResponse,
+        logging_obj: LoggingClass,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> litellm.ModelResponse:
+        raise NotImplementedError(
+            "Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
+        )
+
     def get_mapped_special_auth_params(self) -> dict:
         return {"token": "azure_ad_token"}

@@ -246,3 +266,22 @@ class AzureOpenAIConfig:
             "westus3",
             "westus4",
         ]
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return AzureOpenAIError(
+            message=error_message, status_code=status_code, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        raise NotImplementedError(
+            "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
+        )
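Two signature changes are worth noting here: `get_supported_openai_params` now takes `model`, and `map_openai_params` takes a required `drop_params` with `api_version` defaulting to an empty string. A hedged sketch; the model name and API version are placeholders:

```python
import litellm

cfg = litellm.AzureOpenAIConfig()
print(cfg.get_supported_openai_params(model="gpt-4o"))

mapped = cfg.map_openai_params(
    non_default_params={"temperature": 0.1},
    optional_params={},
    model="gpt-4o",
    drop_params=False,
    api_version="2024-02-01",  # parsed as year-month-day by the code above
)
print(mapped)
```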
@@ -1,7 +1,27 @@
-from typing import Union
+from typing import Optional, Union

 import httpx

+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class AzureOpenAIError(BaseLLMException):
+    def __init__(
+        self,
+        status_code,
+        message,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )
+
+
 def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
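Because `AzureOpenAIError` now subclasses `BaseLLMException`, its default `httpx` request/response objects come from the shared base. A small sketch; the import path follows the `..common_utils` references in the handler hunks, i.e. it assumes `litellm/llms/azure/common_utils.py`:

```python
from litellm.llms.azure.common_utils import AzureOpenAIError

try:
    raise AzureOpenAIError(status_code=429, message="simulated rate limit")
except AzureOpenAIError as err:
    # request/response are synthesized by BaseLLMException when not supplied
    print(err.status_code, err.response.status_code)  # 429 429
```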
@@ -19,104 +19,16 @@ from litellm.utils import (
     convert_to_model_response_object,
 )

-from .base import BaseLLM
-from .openai.completion.handler import OpenAITextCompletion
-from .openai.completion.transformation import OpenAITextCompletionConfig
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ...base import BaseLLM
+from ...openai.completion.handler import OpenAITextCompletion
+from ...openai.completion.transformation import OpenAITextCompletionConfig
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import AzureOpenAIError

 openai_text_completion_config = OpenAITextCompletionConfig()


-class AzureOpenAIError(Exception):
-    def __init__(
-        self,
-        status_code,
-        message,
-        request: Optional[httpx.Request] = None,
-        response: Optional[httpx.Response] = None,
-        headers: Optional[httpx.Headers] = None,
-    ):
-        self.status_code = status_code
-        self.message = message
-        self.headers = headers
-        if request:
-            self.request = request
-        else:
-            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
-        if response:
-            self.response = response
-        else:
-            self.response = httpx.Response(
-                status_code=status_code, request=self.request
-            )
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class AzureOpenAIConfig(OpenAIConfig):
-    """
-    Reference: https://platform.openai.com/docs/api-reference/chat/create
-
-    The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
-
-    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
-
-    - `function_call` (string or object): This optional parameter controls how the model calls functions.
-
-    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
-
-    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
-
-    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
-
-    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
-
-    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
-
-    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
-
-    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
-
-    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
-    """
-
-    def __init__(
-        self,
-        frequency_penalty: Optional[int] = None,
-        function_call: Optional[Union[str, dict]] = None,
-        functions: Optional[list] = None,
-        logit_bias: Optional[dict] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[int] = None,
-        stop: Optional[Union[str, list]] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
-    ) -> None:
-        super().__init__(
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            stop=stop,
-            temperature=temperature,
-            top_p=top_p,
-        )
-
-
 def select_azure_base_url_or_endpoint(azure_client_params: dict):
-    # azure_client_params = {
-    #     "api_version": api_version,
-    #     "azure_endpoint": api_base,
-    #     "azure_deployment": model,
-    #     "http_client": litellm.client_session,
-    #     "max_retries": max_retries,
-    #     "timeout": timeout,
-    # }
     azure_endpoint = azure_client_params.get("azure_endpoint", None)
     if azure_endpoint is not None:
         # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
litellm/llms/azure/completion/transformation.py
Normal file
53
litellm/llms/azure/completion/transformation.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from ...openai.completion.transformation import OpenAITextCompletionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAITextConfig(OpenAITextCompletionConfig):
|
||||||
|
"""
|
||||||
|
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
||||||
|
|
||||||
|
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
|
||||||
|
|
||||||
|
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||||
|
|
||||||
|
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||||
|
|
||||||
|
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||||
|
|
||||||
|
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||||
|
|
||||||
|
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||||
|
|
||||||
|
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||||
|
|
||||||
|
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||||
|
|
||||||
|
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||||
|
|
||||||
|
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||||
|
|
||||||
|
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
frequency_penalty: Optional[int] = None,
|
||||||
|
logit_bias: Optional[dict] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
|
presence_penalty: Optional[int] = None,
|
||||||
|
stop: Optional[Union[str, list]] = None,
|
||||||
|
temperature: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
n=n,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
stop=stop,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
|
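The new module only specializes `OpenAITextCompletionConfig`; per the `__init__.py` hunk earlier in this PR it is re-exported as `litellm.AzureOpenAITextConfig`. An illustrative instantiation:

```python
import litellm

text_cfg = litellm.AzureOpenAITextConfig(max_tokens=256, temperature=0)
print(isinstance(text_cfg, litellm.OpenAITextCompletionConfig))  # True
```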
@@ -1 +0,0 @@
-from .handler import AzureAIChatCompletion
@@ -1,59 +1,3 @@
-from typing import Any, Callable, List, Optional, Union
-
-from httpx._config import Timeout
-
-from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
-from litellm.llms.openai.openai import OpenAIChatCompletion
-from litellm.types.utils import ModelResponse
-from litellm.utils import CustomStreamWrapper
-
-from .transformation import AzureAIStudioConfig
-
-
-class AzureAIChatCompletion(OpenAIChatCompletion):
-    def completion(
-        self,
-        model_response: ModelResponse,
-        timeout: Union[float, Timeout],
-        optional_params: dict,
-        logging_obj: Any,
-        model: Optional[str] = None,
-        messages: Optional[list] = None,
-        print_verbose: Optional[Callable[..., Any]] = None,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        acompletion: bool = False,
-        litellm_params=None,
-        logger_fn=None,
-        headers: Optional[dict] = None,
-        custom_prompt_dict: dict = {},
-        client=None,
-        organization: Optional[str] = None,
-        custom_llm_provider: Optional[str] = None,
-        drop_params: Optional[bool] = None,
-    ):
-
-        transformed_messages = AzureAIStudioConfig()._transform_messages(
-            messages=messages  # type: ignore
-        )
-
-        return super().completion(
-            model_response,
-            timeout,
-            optional_params,
-            logging_obj,
-            model,
-            transformed_messages,
-            print_verbose,
-            api_key,
-            api_base,
-            acompletion,
-            litellm_params,
-            logger_fn,
-            headers,
-            custom_prompt_dict,
-            client,
-            organization,
-            custom_llm_provider,
-            drop_params,
-        )
+"""
+LLM Calling done in `openai/openai.py`
+"""
@@ -13,6 +13,7 @@ from typing import (
    Iterator,
    List,
    Optional,
    TypedDict,
    Union,
)

@@ -34,15 +35,25 @@ class BaseLLMException(Exception):
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[httpx.Headers, Dict]] = None,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message: str = message
        self.headers = headers
        self.request = httpx.Request(method="POST", url="https://docs.litellm.ai/docs")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        if request:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST", url="https://docs.litellm.ai/docs"
            )
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
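A short sketch of the new request/response fallback above. The import path is an assumption for illustration; adjust it to wherever `BaseLLMException` lives in your checkout:

```python
import httpx

# Assumed import path - not confirmed by this diff.
from litellm.llms.base_llm.transformation import BaseLLMException

# When an httpx.Request / httpx.Response pair is supplied, it is preserved as-is.
req = httpx.Request(method="POST", url="https://example.com/v1/chat/completions")
resp = httpx.Response(status_code=429, request=req)
err = BaseLLMException(status_code=429, message="rate limited", request=req, response=resp)
assert err.response is resp

# When omitted, a placeholder request/response is synthesized from the status code.
fallback_err = BaseLLMException(status_code=500, message="provider error")
assert fallback_err.response.status_code == 500
```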
@@ -117,12 +128,6 @@ class BaseConfig(ABC):
    ) -> dict:
        pass

    @abstractmethod
    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        pass

    @abstractmethod
    def transform_response(
        self,

@@ -133,7 +138,8 @@ class BaseConfig(ABC):
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        encoding: str,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
@@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
import types
from typing import Optional, Union

from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig


class CerebrasConfig:
class CerebrasConfig(OpenAIGPTConfig):
    """
    Reference: https://inference-docs.cerebras.ai/api-reference/chat-completions

@@ -18,9 +20,7 @@ class CerebrasConfig:
    max_tokens: Optional[int] = None
    response_format: Optional[dict] = None
    seed: Optional[int] = None
    stop: Optional[str] = None
    stream: Optional[bool] = None
    temperature: Optional[float] = None
    top_p: Optional[int] = None
    tool_choice: Optional[str] = None
    tools: Optional[list] = None

@@ -46,21 +46,7 @@ class CerebrasConfig:

    @classmethod
    def get_config(cls):
        return {
        return super().get_config()
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> list:
        """

@@ -83,7 +69,11 @@ class CerebrasConfig:
        ]

    def map_openai_params(
        self, model: str, non_default_params: dict, optional_params: dict
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
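A minimal sketch of the new `map_openai_params` signature (the `drop_params` argument is now required and keyword order changed). The module path and model name are assumptions for illustration:

```python
# Assumed module path for the file edited in this hunk.
from litellm.llms.cerebras.chat.transformation import CerebrasConfig

config = CerebrasConfig()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 128, "temperature": 0.2},
    optional_params={},
    model="llama3.1-8b",  # illustrative Cerebras model name
    drop_params=False,
)
print(optional_params)  # provider-ready params for the Cerebras chat completions API
```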
@@ -148,6 +148,7 @@ class ClarifaiConfig(BaseConfig):
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,

@@ -49,6 +49,10 @@ class CloudflareChatConfig(BaseConfig):
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def validate_environment(
        self,
        headers: dict,

@@ -120,6 +124,7 @@ class CloudflareChatConfig(BaseConfig):
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,

@@ -216,7 +216,8 @@ class CohereChatConfig(BaseConfig):
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        encoding: str,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:

@@ -217,7 +217,8 @@ class CohereTextConfig(BaseConfig):
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        encoding: str,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
@@ -211,7 +211,6 @@ class AsyncHTTPHandler:
                headers=headers,
            )
        except httpx.HTTPStatusError as e:

            if stream is True:
                setattr(e, "message", await e.response.aread())
                setattr(e, "text", await e.response.aread())

@@ -51,7 +51,8 @@ class BaseLLMHTTPHandler:
        logging_obj: LiteLLMLoggingObj,
        messages: list,
        optional_params: dict,
        encoding: str,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ):
        async_httpx_client = get_async_httpx_client(

@@ -75,6 +76,7 @@ class BaseLLMHTTPHandler:
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )

@@ -163,6 +165,7 @@ class BaseLLMHTTPHandler:
            api_key=api_key,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )

@@ -211,6 +214,7 @@ class BaseLLMHTTPHandler:
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )
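The hunks above (and the provider configs earlier in this diff) all thread a `litellm_params` dict into `transform_response`. A self-contained toy, not litellm's real classes, just to illustrate the widened signature:

```python
from typing import Any, List, Optional


class ToyProviderConfig:
    """Illustrative stand-in for a provider config; names here are hypothetical."""

    def transform_response(
        self,
        model: str,
        raw_response: Any,
        model_response: dict,
        request_data: dict,
        messages: List[dict],
        optional_params: dict,
        litellm_params: dict,  # newly threaded through by the HTTP handler
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> dict:
        # e.g. provider-specific context can now be read from litellm_params
        model_response["custom_llm_provider"] = litellm_params.get("custom_llm_provider")
        return model_response


print(
    ToyProviderConfig().transform_response(
        model="toy-model",
        raw_response=None,
        model_response={},
        request_data={},
        messages=[{"role": "user", "content": "hi"}],
        optional_params={},
        litellm_params={"custom_llm_provider": "toy"},
        encoding=None,
    )
)
```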
@@ -10,14 +10,14 @@ from pydantic import BaseModel
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ...openai_like.chat.transformation import OpenAILikeChatConfig
from ...prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
    strip_name_from_messages,
)


class DatabricksConfig(OpenAIGPTConfig):
class DatabricksConfig(OpenAILikeChatConfig):
    """
    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
    """

@@ -85,30 +85,6 @@ class DatabricksConfig(OpenAIGPTConfig):

        return False

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            if param == "n":
                optional_params["n"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "response_format":
                optional_params["response_format"] = value
        return optional_params

    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
120 litellm/llms/deepinfra/chat/transformation.py Normal file
@@ -0,0 +1,120 @@
import types
from typing import Optional, Tuple, Union

import litellm
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.secret_managers.main import get_secret_str


class DeepInfraConfig(OpenAIGPTConfig):
    """
    Reference: https://deepinfra.com/docs/advanced/openai_api

    The class `DeepInfraConfig` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None
    tools: Optional[list] = None
    tool_choice: Optional[Union[str, dict]] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str):
        return [
            "stream",
            "frequency_penalty",
            "function_call",
            "functions",
            "logit_bias",
            "max_tokens",
            "max_completion_tokens",
            "n",
            "presence_penalty",
            "stop",
            "temperature",
            "top_p",
            "response_format",
            "tools",
            "tool_choice",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if (
                param == "temperature"
                and value == 0
                and model == "mistralai/Mistral-7B-Instruct-v0.1"
            ):  # this model does not support temperature == 0
                value = 0.0001  # close to 0
            if param == "tool_choice":
                if (
                    value != "auto" and value != "none"
                ):  # https://deepinfra.com/docs/advanced/function_calling
                    ## UNSUPPORTED TOOL CHOICE VALUE
                    if litellm.drop_params is True or drop_params is True:
                        value = None
                    else:
                        raise litellm.utils.UnsupportedParamsError(
                            message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
                                value
                            ),
                            status_code=400,
                        )
            elif param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
        api_base = (
            api_base
            or get_secret_str("DEEPINFRA_API_BASE")
            or "https://api.deepinfra.com/v1/openai"
        )
        dynamic_api_key = api_key or get_secret_str("DEEPINFRA_API_KEY")
        return api_base, dynamic_api_key
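The `tool_choice` guard above is easiest to see with a quick call. A minimal sketch using the module path added in this file, assuming `litellm.drop_params` is left at its default (False); the model name and tool payload are illustrative only:

```python
import litellm
from litellm.llms.deepinfra.chat.transformation import DeepInfraConfig

config = DeepInfraConfig()
forced_tool = {"type": "function", "function": {"name": "get_weather"}}  # hypothetical tool

# With drop_params=True the unsupported tool_choice value is silently dropped.
params = config.map_openai_params(
    non_default_params={"tool_choice": forced_tool, "max_tokens": 64},
    optional_params={},
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    drop_params=True,
)
print(params)  # {"max_tokens": 64} - tool_choice was removed

# With drop_params=False it raises instead of sending an unsupported value.
try:
    config.map_openai_params(
        non_default_params={"tool_choice": forced_tool},
        optional_params={},
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        drop_params=False,
    )
except litellm.utils.UnsupportedParamsError as e:
    print(f"rejected: {e}")
```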
@@ -1,421 +0,0 @@
# ####################################
|
|
||||||
# ######### DEPRECATED FILE ##########
|
|
||||||
# ####################################
|
|
||||||
# # logic moved to `vertex_httpx.py` #
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import types
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from packaging.version import Version
|
|
||||||
|
|
||||||
import litellm
|
|
||||||
from litellm import verbose_logger
|
|
||||||
from litellm.utils import Choices, Message, ModelResponse, Usage
|
|
||||||
|
|
||||||
from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiError(Exception):
|
|
||||||
def __init__(self, status_code, message):
|
|
||||||
self.status_code = status_code
|
|
||||||
self.message = message
|
|
||||||
self.request = httpx.Request(
|
|
||||||
method="POST",
|
|
||||||
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
|
|
||||||
)
|
|
||||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
|
||||||
super().__init__(
|
|
||||||
self.message
|
|
||||||
) # Call the base class constructor with the parameters it needs
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiConfig:
|
|
||||||
"""
|
|
||||||
Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
|
|
||||||
|
|
||||||
The class `GeminiConfig` provides configuration for the Gemini's API interface. Here are the parameters:
|
|
||||||
|
|
||||||
- `candidate_count` (int): Number of generated responses to return.
|
|
||||||
|
|
||||||
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
|
|
||||||
|
|
||||||
- `max_output_tokens` (int): The maximum number of tokens to include in a candidate. If unset, this will default to output_token_limit specified in the model's specification.
|
|
||||||
|
|
||||||
- `temperature` (float): Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature attribute of the Model returned the genai.get_model function. Values can range from [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied and creative, while a value closer to 0.0 will typically result in more straightforward responses from the model.
|
|
||||||
|
|
||||||
- `top_p` (float): Optional. The maximum cumulative probability of tokens to consider when sampling.
|
|
||||||
|
|
||||||
- `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
candidate_count: Optional[int] = None
|
|
||||||
stop_sequences: Optional[list] = None
|
|
||||||
max_output_tokens: Optional[int] = None
|
|
||||||
temperature: Optional[float] = None
|
|
||||||
top_p: Optional[float] = None
|
|
||||||
top_k: Optional[int] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
candidate_count: Optional[int] = None,
|
|
||||||
stop_sequences: Optional[list] = None,
|
|
||||||
max_output_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
locals_ = locals()
|
|
||||||
for key, value in locals_.items():
|
|
||||||
if key != "self" and value is not None:
|
|
||||||
setattr(self.__class__, key, value)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_config(cls):
|
|
||||||
return {
|
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if not k.startswith("__")
|
|
||||||
and not isinstance(
|
|
||||||
v,
|
|
||||||
(
|
|
||||||
types.FunctionType,
|
|
||||||
types.BuiltinFunctionType,
|
|
||||||
classmethod,
|
|
||||||
staticmethod,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
and v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# class TextStreamer:
|
|
||||||
# """
|
|
||||||
# A class designed to return an async stream from AsyncGenerateContentResponse object.
|
|
||||||
# """
|
|
||||||
|
|
||||||
# def __init__(self, response):
|
|
||||||
# self.response = response
|
|
||||||
# self._aiter = self.response.__aiter__()
|
|
||||||
|
|
||||||
# async def __aiter__(self):
|
|
||||||
# while True:
|
|
||||||
# try:
|
|
||||||
# # This will manually advance the async iterator.
|
|
||||||
# # In the case the next object doesn't exists, __anext__() will simply raise a StopAsyncIteration exception
|
|
||||||
# next_object = await self._aiter.__anext__()
|
|
||||||
# yield next_object
|
|
||||||
# except StopAsyncIteration:
|
|
||||||
# # After getting all items from the async iterator, stop iterating
|
|
||||||
# break
|
|
||||||
|
|
||||||
|
|
||||||
# def supports_system_instruction():
|
|
||||||
# import google.generativeai as genai
|
|
||||||
|
|
||||||
# gemini_pkg_version = Version(genai.__version__)
|
|
||||||
# return gemini_pkg_version >= Version("0.5.0")
|
|
||||||
|
|
||||||
|
|
||||||
# def completion(
|
|
||||||
# model: str,
|
|
||||||
# messages: list,
|
|
||||||
# model_response: ModelResponse,
|
|
||||||
# print_verbose: Callable,
|
|
||||||
# api_key,
|
|
||||||
# encoding,
|
|
||||||
# logging_obj,
|
|
||||||
# custom_prompt_dict: dict,
|
|
||||||
# acompletion: bool = False,
|
|
||||||
# optional_params=None,
|
|
||||||
# litellm_params=None,
|
|
||||||
# logger_fn=None,
|
|
||||||
# ):
|
|
||||||
# try:
|
|
||||||
# import google.generativeai as genai # type: ignore
|
|
||||||
# except Exception:
|
|
||||||
# raise Exception(
|
|
||||||
# "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
|
|
||||||
# )
|
|
||||||
# genai.configure(api_key=api_key)
|
|
||||||
# system_prompt = ""
|
|
||||||
# if model in custom_prompt_dict:
|
|
||||||
# # check if the model has a registered custom prompt
|
|
||||||
# model_prompt_details = custom_prompt_dict[model]
|
|
||||||
# prompt = custom_prompt(
|
|
||||||
# role_dict=model_prompt_details["roles"],
|
|
||||||
# initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
|
||||||
# final_prompt_value=model_prompt_details["final_prompt_value"],
|
|
||||||
# messages=messages,
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# system_prompt, messages = get_system_prompt(messages=messages)
|
|
||||||
# prompt = prompt_factory(
|
|
||||||
# model=model, messages=messages, custom_llm_provider="gemini"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# ## Load Config
|
|
||||||
# inference_params = copy.deepcopy(optional_params)
|
|
||||||
# stream = inference_params.pop("stream", None)
|
|
||||||
|
|
||||||
# # Handle safety settings
|
|
||||||
# safety_settings_param = inference_params.pop("safety_settings", None)
|
|
||||||
# safety_settings = None
|
|
||||||
# if safety_settings_param:
|
|
||||||
# safety_settings = [
|
|
||||||
# genai.types.SafetySettingDict(x) for x in safety_settings_param
|
|
||||||
# ]
|
|
||||||
|
|
||||||
# config = litellm.GeminiConfig.get_config()
|
|
||||||
# for k, v in config.items():
|
|
||||||
# if (
|
|
||||||
# k not in inference_params
|
|
||||||
# ): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
|
|
||||||
# inference_params[k] = v
|
|
||||||
|
|
||||||
# ## LOGGING
|
|
||||||
# logging_obj.pre_call(
|
|
||||||
# input=prompt,
|
|
||||||
# api_key="",
|
|
||||||
# additional_args={
|
|
||||||
# "complete_input_dict": {
|
|
||||||
# "inference_params": inference_params,
|
|
||||||
# "system_prompt": system_prompt,
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# )
|
|
||||||
# ## COMPLETION CALL
|
|
||||||
# try:
|
|
||||||
# _params = {"model_name": "models/{}".format(model)}
|
|
||||||
# _system_instruction = supports_system_instruction()
|
|
||||||
# if _system_instruction and len(system_prompt) > 0:
|
|
||||||
# _params["system_instruction"] = system_prompt
|
|
||||||
# _model = genai.GenerativeModel(**_params)
|
|
||||||
# if stream is True:
|
|
||||||
# if acompletion is True:
|
|
||||||
|
|
||||||
# async def async_streaming():
|
|
||||||
# try:
|
|
||||||
# response = await _model.generate_content_async(
|
|
||||||
# contents=prompt,
|
|
||||||
# generation_config=genai.types.GenerationConfig(
|
|
||||||
# **inference_params
|
|
||||||
# ),
|
|
||||||
# safety_settings=safety_settings,
|
|
||||||
# stream=True,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# response = litellm.CustomStreamWrapper(
|
|
||||||
# TextStreamer(response),
|
|
||||||
# model,
|
|
||||||
# custom_llm_provider="gemini",
|
|
||||||
# logging_obj=logging_obj,
|
|
||||||
# )
|
|
||||||
# return response
|
|
||||||
# except Exception as e:
|
|
||||||
# raise GeminiError(status_code=500, message=str(e))
|
|
||||||
|
|
||||||
# return async_streaming()
|
|
||||||
# response = _model.generate_content(
|
|
||||||
# contents=prompt,
|
|
||||||
# generation_config=genai.types.GenerationConfig(**inference_params),
|
|
||||||
# safety_settings=safety_settings,
|
|
||||||
# stream=True,
|
|
||||||
# )
|
|
||||||
# return response
|
|
||||||
# elif acompletion == True:
|
|
||||||
# return async_completion(
|
|
||||||
# _model=_model,
|
|
||||||
# model=model,
|
|
||||||
# prompt=prompt,
|
|
||||||
# inference_params=inference_params,
|
|
||||||
# safety_settings=safety_settings,
|
|
||||||
# logging_obj=logging_obj,
|
|
||||||
# print_verbose=print_verbose,
|
|
||||||
# model_response=model_response,
|
|
||||||
# messages=messages,
|
|
||||||
# encoding=encoding,
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# params = {
|
|
||||||
# "contents": prompt,
|
|
||||||
# "generation_config": genai.types.GenerationConfig(**inference_params),
|
|
||||||
# "safety_settings": safety_settings,
|
|
||||||
# }
|
|
||||||
# response = _model.generate_content(**params)
|
|
||||||
# except Exception as e:
|
|
||||||
# raise GeminiError(
|
|
||||||
# message=str(e),
|
|
||||||
# status_code=500,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# ## LOGGING
|
|
||||||
# logging_obj.post_call(
|
|
||||||
# input=prompt,
|
|
||||||
# api_key="",
|
|
||||||
# original_response=response,
|
|
||||||
# additional_args={"complete_input_dict": {}},
|
|
||||||
# )
|
|
||||||
# print_verbose(f"raw model_response: {response}")
|
|
||||||
# ## RESPONSE OBJECT
|
|
||||||
# completion_response = response
|
|
||||||
# try:
|
|
||||||
# choices_list = []
|
|
||||||
# for idx, item in enumerate(completion_response.candidates):
|
|
||||||
# if len(item.content.parts) > 0:
|
|
||||||
# message_obj = Message(content=item.content.parts[0].text)
|
|
||||||
# else:
|
|
||||||
# message_obj = Message(content=None)
|
|
||||||
# choice_obj = Choices(index=idx, message=message_obj)
|
|
||||||
# choices_list.append(choice_obj)
|
|
||||||
# model_response.choices = choices_list
|
|
||||||
# except Exception as e:
|
|
||||||
# verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
|
|
||||||
# raise GeminiError(
|
|
||||||
# message=traceback.format_exc(), status_code=response.status_code
|
|
||||||
# )
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# completion_response = model_response["choices"][0]["message"].get("content")
|
|
||||||
# if completion_response is None:
|
|
||||||
# raise Exception
|
|
||||||
# except Exception:
|
|
||||||
# original_response = f"response: {response}"
|
|
||||||
# if hasattr(response, "candidates"):
|
|
||||||
# original_response = f"response: {response.candidates}"
|
|
||||||
# if "SAFETY" in original_response:
|
|
||||||
# original_response += (
|
|
||||||
# "\nThe candidate content was flagged for safety reasons."
|
|
||||||
# )
|
|
||||||
# elif "RECITATION" in original_response:
|
|
||||||
# original_response += (
|
|
||||||
# "\nThe candidate content was flagged for recitation reasons."
|
|
||||||
# )
|
|
||||||
# raise GeminiError(
|
|
||||||
# status_code=400,
|
|
||||||
# message=f"No response received. Original response - {original_response}",
|
|
||||||
# )
|
|
||||||
|
|
||||||
# ## CALCULATING USAGE
|
|
||||||
# prompt_str = ""
|
|
||||||
# for m in messages:
|
|
||||||
# if isinstance(m["content"], str):
|
|
||||||
# prompt_str += m["content"]
|
|
||||||
# elif isinstance(m["content"], list):
|
|
||||||
# for content in m["content"]:
|
|
||||||
# if content["type"] == "text":
|
|
||||||
# prompt_str += content["text"]
|
|
||||||
# prompt_tokens = len(encoding.encode(prompt_str))
|
|
||||||
# completion_tokens = len(
|
|
||||||
# encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
|
||||||
# )
|
|
||||||
|
|
||||||
# model_response.created = int(time.time())
|
|
||||||
# model_response.model = "gemini/" + model
|
|
||||||
# usage = Usage(
|
|
||||||
# prompt_tokens=prompt_tokens,
|
|
||||||
# completion_tokens=completion_tokens,
|
|
||||||
# total_tokens=prompt_tokens + completion_tokens,
|
|
||||||
# )
|
|
||||||
# setattr(model_response, "usage", usage)
|
|
||||||
# return model_response
|
|
||||||
|
|
||||||
|
|
||||||
# async def async_completion(
|
|
||||||
# _model,
|
|
||||||
# model,
|
|
||||||
# prompt,
|
|
||||||
# inference_params,
|
|
||||||
# safety_settings,
|
|
||||||
# logging_obj,
|
|
||||||
# print_verbose,
|
|
||||||
# model_response,
|
|
||||||
# messages,
|
|
||||||
# encoding,
|
|
||||||
# ):
|
|
||||||
# import google.generativeai as genai # type: ignore
|
|
||||||
|
|
||||||
# response = await _model.generate_content_async(
|
|
||||||
# contents=prompt,
|
|
||||||
# generation_config=genai.types.GenerationConfig(**inference_params),
|
|
||||||
# safety_settings=safety_settings,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# ## LOGGING
|
|
||||||
# logging_obj.post_call(
|
|
||||||
# input=prompt,
|
|
||||||
# api_key="",
|
|
||||||
# original_response=response,
|
|
||||||
# additional_args={"complete_input_dict": {}},
|
|
||||||
# )
|
|
||||||
# print_verbose(f"raw model_response: {response}")
|
|
||||||
# ## RESPONSE OBJECT
|
|
||||||
# completion_response = response
|
|
||||||
# try:
|
|
||||||
# choices_list = []
|
|
||||||
# for idx, item in enumerate(completion_response.candidates):
|
|
||||||
# if len(item.content.parts) > 0:
|
|
||||||
# message_obj = Message(content=item.content.parts[0].text)
|
|
||||||
# else:
|
|
||||||
# message_obj = Message(content=None)
|
|
||||||
# choice_obj = Choices(index=idx, message=message_obj)
|
|
||||||
# choices_list.append(choice_obj)
|
|
||||||
# model_response["choices"] = choices_list
|
|
||||||
# except Exception as e:
|
|
||||||
# verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
|
|
||||||
# raise GeminiError(
|
|
||||||
# message=traceback.format_exc(), status_code=response.status_code
|
|
||||||
# )
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# completion_response = model_response["choices"][0]["message"].get("content")
|
|
||||||
# if completion_response is None:
|
|
||||||
# raise Exception
|
|
||||||
# except Exception:
|
|
||||||
# original_response = f"response: {response}"
|
|
||||||
# if hasattr(response, "candidates"):
|
|
||||||
# original_response = f"response: {response.candidates}"
|
|
||||||
# if "SAFETY" in original_response:
|
|
||||||
# original_response += (
|
|
||||||
# "\nThe candidate content was flagged for safety reasons."
|
|
||||||
# )
|
|
||||||
# elif "RECITATION" in original_response:
|
|
||||||
# original_response += (
|
|
||||||
# "\nThe candidate content was flagged for recitation reasons."
|
|
||||||
# )
|
|
||||||
# raise GeminiError(
|
|
||||||
# status_code=400,
|
|
||||||
# message=f"No response received. Original response - {original_response}",
|
|
||||||
# )
|
|
||||||
|
|
||||||
# ## CALCULATING USAGE
|
|
||||||
# prompt_str = ""
|
|
||||||
# for m in messages:
|
|
||||||
# if isinstance(m["content"], str):
|
|
||||||
# prompt_str += m["content"]
|
|
||||||
# elif isinstance(m["content"], list):
|
|
||||||
# for content in m["content"]:
|
|
||||||
# if content["type"] == "text":
|
|
||||||
# prompt_str += content["text"]
|
|
||||||
# prompt_tokens = len(encoding.encode(prompt_str))
|
|
||||||
# completion_tokens = len(
|
|
||||||
# encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
|
||||||
# )
|
|
||||||
|
|
||||||
# model_response["created"] = int(time.time())
|
|
||||||
# model_response["model"] = "gemini/" + model
|
|
||||||
# usage = Usage(
|
|
||||||
# prompt_tokens=prompt_tokens,
|
|
||||||
# completion_tokens=completion_tokens,
|
|
||||||
# total_tokens=prompt_tokens + completion_tokens,
|
|
||||||
# )
|
|
||||||
# model_response.usage = usage
|
|
||||||
# return model_response
|
|
||||||
|
|
||||||
|
|
||||||
# def embedding():
|
|
||||||
# # logic for parsing in - calling - parsing out model embedding calls
|
|
||||||
# pass
|
|
750 litellm/llms/huggingface/chat/handler.py Normal file
@@ -0,0 +1,750 @@
## Uses the huggingface text generation inference API
|
||||||
|
import copy
|
||||||
|
import enum
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from enum import Enum
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
get_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import requests
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
HTTPHandler,
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.llms.huggingface.chat.transformation import (
|
||||||
|
HuggingfaceChatConfig as HuggingfaceConfig,
|
||||||
|
)
|
||||||
|
from litellm.secret_managers.main import get_secret_str
|
||||||
|
from litellm.types.completion import ChatCompletionMessageToolCallParam
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.types.utils import Logprobs as TextCompletionLogprobs
|
||||||
|
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
|
||||||
|
|
||||||
|
from ...base import BaseLLM
|
||||||
|
from ...prompt_templates.factory import custom_prompt, prompt_factory
|
||||||
|
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks
|
||||||
|
|
||||||
|
hf_chat_config = HuggingfaceConfig()
|
||||||
|
|
||||||
|
|
||||||
|
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
|
||||||
|
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_task_embedding_for_model(
|
||||||
|
model: str, task_type: Optional[str], api_base: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
if task_type is not None:
|
||||||
|
if task_type in get_args(hf_tasks_embeddings):
|
||||||
|
return task_type
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid task_type={}. Expected one of={}".format(
|
||||||
|
task_type, hf_tasks_embeddings
|
||||||
|
)
|
||||||
|
)
|
||||||
|
http_client = HTTPHandler(concurrent_limit=1)
|
||||||
|
|
||||||
|
model_info = http_client.get(url=api_base)
|
||||||
|
|
||||||
|
model_info_dict = model_info.json()
|
||||||
|
|
||||||
|
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
|
||||||
|
|
||||||
|
return pipeline_tag
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_hf_task_embedding_for_model(
|
||||||
|
model: str, task_type: Optional[str], api_base: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
if task_type is not None:
|
||||||
|
if task_type in get_args(hf_tasks_embeddings):
|
||||||
|
return task_type
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid task_type={}. Expected one of={}".format(
|
||||||
|
task_type, hf_tasks_embeddings
|
||||||
|
)
|
||||||
|
)
|
||||||
|
http_client = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.HUGGINGFACE,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_info = await http_client.get(url=api_base)
|
||||||
|
|
||||||
|
model_info_dict = model_info.json()
|
||||||
|
|
||||||
|
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
|
||||||
|
|
||||||
|
return pipeline_tag
|
||||||
|
|
||||||
|
|
||||||
|
async def make_call(
|
||||||
|
client: Optional[AsyncHTTPHandler],
|
||||||
|
api_base: str,
|
||||||
|
headers: dict,
|
||||||
|
data: str,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
logging_obj,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
json_mode: bool,
|
||||||
|
) -> Tuple[Any, httpx.Headers]:
|
||||||
|
if client is None:
|
||||||
|
client = litellm.module_level_aclient
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
error_headers = getattr(e, "headers", None)
|
||||||
|
error_response = getattr(e, "response", None)
|
||||||
|
if error_headers is None and error_response:
|
||||||
|
error_headers = getattr(error_response, "headers", None)
|
||||||
|
raise HuggingfaceError(
|
||||||
|
status_code=e.response.status_code,
|
||||||
|
message=str(await e.response.aread()),
|
||||||
|
headers=cast(dict, error_headers) if error_headers else None,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||||
|
if isinstance(e, exception):
|
||||||
|
raise e
|
||||||
|
raise HuggingfaceError(status_code=500, message=str(e))
|
||||||
|
|
||||||
|
# LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response=response, # Pass the completion stream for logging
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.aiter_lines(), response.headers
|
||||||
|
|
||||||
|
|
||||||
|
class Huggingface(BaseLLM):
|
||||||
|
_client_session: Optional[httpx.Client] = None
|
||||||
|
_aclient_session: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def completion( # noqa: PLR0915
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: Optional[str],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
timeout: float,
|
||||||
|
encoding,
|
||||||
|
api_key,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
custom_prompt_dict={},
|
||||||
|
acompletion: bool = False,
|
||||||
|
logger_fn=None,
|
||||||
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||||
|
headers: dict = {},
|
||||||
|
):
|
||||||
|
super().completion()
|
||||||
|
exception_mapping_worked = False
|
||||||
|
try:
|
||||||
|
task, model = hf_chat_config.get_hf_task_for_model(model)
|
||||||
|
litellm_params["task"] = task
|
||||||
|
headers = hf_chat_config.validate_environment(
|
||||||
|
api_key=api_key,
|
||||||
|
headers=headers,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
|
completion_url = hf_chat_config.get_api_base(api_base=api_base, model=model)
|
||||||
|
data = hf_chat_config.transform_request(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=data,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"headers": headers,
|
||||||
|
"api_base": completion_url,
|
||||||
|
"acompletion": acompletion,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
if acompletion is True:
|
||||||
|
### ASYNC STREAMING
|
||||||
|
if optional_params.get("stream", False):
|
||||||
|
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages) # type: ignore
|
||||||
|
else:
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, model=model, optional_params=optional_params, timeout=timeout, litellm_params=litellm_params) # type: ignore
|
||||||
|
if client is None or not isinstance(client, HTTPHandler):
|
||||||
|
client = HTTPHandler()
|
||||||
|
### SYNC STREAMING
|
||||||
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
|
response = client.post(
|
||||||
|
url=completion_url,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
stream=optional_params["stream"],
|
||||||
|
)
|
||||||
|
return response.iter_lines()
|
||||||
|
### SYNC COMPLETION
|
||||||
|
else:
|
||||||
|
response = client.post(
|
||||||
|
url=completion_url,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
return hf_chat_config.transform_response(
|
||||||
|
model=model,
|
||||||
|
raw_response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
request_data=data,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
json_mode=None,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise HuggingfaceError(
|
||||||
|
status_code=e.response.status_code,
|
||||||
|
message=e.response.text,
|
||||||
|
headers=e.response.headers,
|
||||||
|
)
|
||||||
|
except HuggingfaceError as e:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
if exception_mapping_worked:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
|
||||||
|
|
||||||
|
async def acompletion(
|
||||||
|
self,
|
||||||
|
api_base: str,
|
||||||
|
data: dict,
|
||||||
|
headers: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
encoding: Any,
|
||||||
|
model: str,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
timeout: float,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
api_key: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
):
|
||||||
|
response: Optional[httpx.Response] = None
|
||||||
|
try:
|
||||||
|
http_client = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.HUGGINGFACE
|
||||||
|
)
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
http_response = await http_client.post(
|
||||||
|
url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
response = http_response
|
||||||
|
|
||||||
|
return hf_chat_config.transform_response(
|
||||||
|
model=model,
|
||||||
|
raw_response=http_response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
request_data=data,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
json_mode=None,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, httpx.TimeoutException):
|
||||||
|
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
|
||||||
|
elif isinstance(e, HuggingfaceError):
|
||||||
|
raise e
|
||||||
|
elif response is not None and hasattr(response, "text"):
|
||||||
|
raise HuggingfaceError(
|
||||||
|
status_code=500,
|
||||||
|
message=f"{str(e)}\n\nOriginal Response: {response.text}",
|
||||||
|
headers=response.headers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
|
||||||
|
|
||||||
|
async def async_streaming(
|
||||||
|
self,
|
||||||
|
logging_obj,
|
||||||
|
api_base: str,
|
||||||
|
data: dict,
|
||||||
|
headers: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
model: str,
|
||||||
|
timeout: float,
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
):
|
||||||
|
completion_stream, _ = await make_call(
|
||||||
|
client=client,
|
||||||
|
api_base=api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
timeout=timeout,
|
||||||
|
json_mode=False,
|
||||||
|
)
|
||||||
|
streamwrapper = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="huggingface",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streamwrapper
|
||||||
|
|
||||||
|
def _transform_input_on_pipeline_tag(
|
||||||
|
self, input: List, pipeline_tag: Optional[str]
|
||||||
|
) -> dict:
|
||||||
|
if pipeline_tag is None:
|
||||||
|
return {"inputs": input}
|
||||||
|
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
|
||||||
|
if len(input) < 2:
|
||||||
|
raise HuggingfaceError(
|
||||||
|
status_code=400,
|
||||||
|
message="sentence-similarity requires 2+ sentences",
|
||||||
|
)
|
||||||
|
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
||||||
|
elif pipeline_tag == "rerank":
|
||||||
|
if len(input) < 2:
|
||||||
|
raise HuggingfaceError(
|
||||||
|
status_code=400,
|
||||||
|
message="reranker requires 2+ sentences",
|
||||||
|
)
|
||||||
|
return {"inputs": {"query": input[0], "texts": input[1:]}}
|
||||||
|
return {"inputs": input} # default to feature-extraction pipeline tag
|
||||||
|
|
||||||
|
async def _async_transform_input(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
task_type: Optional[str],
|
||||||
|
embed_url: str,
|
||||||
|
input: List,
|
||||||
|
optional_params: dict,
|
||||||
|
) -> dict:
|
||||||
|
hf_task = await async_get_hf_task_embedding_for_model(
|
||||||
|
model=model, task_type=task_type, api_base=embed_url
|
||||||
|
)
|
||||||
|
|
||||||
|
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
||||||
|
|
||||||
|
if len(optional_params.keys()) > 0:
|
||||||
|
data["options"] = optional_params
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
|
||||||
|
special_options_keys = HuggingfaceConfig().get_special_options_params()
|
||||||
|
special_parameters_keys = [
|
||||||
|
"min_length",
|
||||||
|
"max_length",
|
||||||
|
"top_k",
|
||||||
|
"top_p",
|
||||||
|
"temperature",
|
||||||
|
"repetition_penalty",
|
||||||
|
"max_time",
|
||||||
|
]
|
||||||
|
|
||||||
|
for k, v in optional_params.items():
|
||||||
|
if k in special_options_keys:
|
||||||
|
data.setdefault("options", {})
|
||||||
|
data["options"][k] = v
|
||||||
|
elif k in special_parameters_keys:
|
||||||
|
data.setdefault("parameters", {})
|
||||||
|
data["parameters"][k] = v
|
||||||
|
else:
|
||||||
|
data[k] = v
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _transform_input(
|
||||||
|
self,
|
||||||
|
input: List,
|
||||||
|
model: str,
|
||||||
|
call_type: Literal["sync", "async"],
|
||||||
|
optional_params: dict,
|
||||||
|
embed_url: str,
|
||||||
|
) -> dict:
|
||||||
|
data: Dict = {}
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
if "sentence-transformers" in model:
|
||||||
|
if len(input) == 0:
|
||||||
|
raise HuggingfaceError(
|
||||||
|
status_code=400,
|
||||||
|
message="sentence transformers requires 2+ sentences",
|
||||||
|
)
|
||||||
|
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
||||||
|
else:
|
||||||
|
data = {"inputs": input}
|
||||||
|
|
||||||
|
task_type = optional_params.pop("input_type", None)
|
||||||
|
|
||||||
|
if call_type == "sync":
|
||||||
|
hf_task = get_hf_task_embedding_for_model(
|
||||||
|
model=model, task_type=task_type, api_base=embed_url
|
||||||
|
)
|
||||||
|
elif call_type == "async":
|
||||||
|
return self._async_transform_input(
|
||||||
|
model=model, task_type=task_type, embed_url=embed_url, input=input
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
data = self._transform_input_on_pipeline_tag(
|
||||||
|
input=input, pipeline_tag=hf_task
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(optional_params.keys()) > 0:
|
||||||
|
data = self._process_optional_params(
|
||||||
|
data=data, optional_params=optional_params
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _process_embedding_response(
|
||||||
|
self,
|
||||||
|
embeddings: dict,
|
||||||
|
model_response: litellm.EmbeddingResponse,
|
||||||
|
model: str,
|
||||||
|
input: List,
|
||||||
|
encoding: Any,
|
||||||
|
) -> litellm.EmbeddingResponse:
|
||||||
|
output_data = []
|
||||||
|
if "similarities" in embeddings:
|
||||||
|
for idx, embedding in embeddings["similarities"]:
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding, # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for idx, embedding in enumerate(embeddings):
|
||||||
|
if isinstance(embedding, float):
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding, # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(embedding, list) and isinstance(embedding[0], float):
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding, # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
                    "index": idx,
                    "embedding": embedding[0][
                        0
                    ],  # flatten list returned from hf
                }
            )
        model_response.object = "list"
        model_response.data = output_data
        model_response.model = model
        input_tokens = 0
        for text in input:
            input_tokens += len(encoding.encode(text))

        setattr(
            model_response,
            "usage",
            litellm.Usage(
                prompt_tokens=input_tokens,
                completion_tokens=input_tokens,
                total_tokens=input_tokens,
                prompt_tokens_details=None,
                completion_tokens_details=None,
            ),
        )
        return model_response

    async def aembedding(
        self,
        model: str,
        input: list,
        model_response: litellm.utils.EmbeddingResponse,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        api_base: str,
        api_key: Optional[str],
        headers: dict,
        encoding: Callable,
        client: Optional[AsyncHTTPHandler] = None,
    ):
        ## TRANSFORMATION ##
        data = self._transform_input(
            input=input,
            model=model,
            call_type="sync",
            optional_params=optional_params,
            embed_url=api_base,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": api_base,
            },
        )
        ## COMPLETION CALL
        if client is None:
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.HUGGINGFACE,
            )

        response = await client.post(api_base, headers=headers, data=json.dumps(data))

        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )

        embeddings = response.json()

        if "error" in embeddings:
            raise HuggingfaceError(status_code=500, message=embeddings["error"])

        ## PROCESS RESPONSE ##
        return self._process_embedding_response(
            embeddings=embeddings,
            model_response=model_response,
            model=model,
            input=input,
            encoding=encoding,
        )

    def embedding(
        self,
        model: str,
        input: list,
        model_response: litellm.EmbeddingResponse,
        optional_params: dict,
        logging_obj: LiteLLMLoggingObj,
        encoding: Callable,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        aembedding: Optional[bool] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        headers={},
    ) -> litellm.EmbeddingResponse:
        super().embedding()
        headers = hf_chat_config.validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            optional_params=optional_params,
            messages=[],
        )
        # print_verbose(f"{model}, {task}")
        embed_url = ""
        if "https" in model:
            embed_url = model
        elif api_base:
            embed_url = api_base
        elif "HF_API_BASE" in os.environ:
            embed_url = os.getenv("HF_API_BASE", "")
        elif "HUGGINGFACE_API_BASE" in os.environ:
            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
        else:
            embed_url = f"https://api-inference.huggingface.co/models/{model}"

        ## ROUTING ##
        if aembedding is True:
            return self.aembedding(
                input=input,
                model_response=model_response,
                timeout=timeout,
                logging_obj=logging_obj,
                headers=headers,
                api_base=embed_url,  # type: ignore
                api_key=api_key,
                client=client if isinstance(client, AsyncHTTPHandler) else None,
                model=model,
                optional_params=optional_params,
                encoding=encoding,
            )

        ## TRANSFORMATION ##

        data = self._transform_input(
            input=input,
            model=model,
            call_type="sync",
            optional_params=optional_params,
            embed_url=embed_url,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": embed_url,
            },
        )
        ## COMPLETION CALL
        if client is None or not isinstance(client, HTTPHandler):
            client = HTTPHandler(concurrent_limit=1)
        response = client.post(embed_url, headers=headers, data=json.dumps(data))

        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )

        embeddings = response.json()

        if "error" in embeddings:
            raise HuggingfaceError(status_code=500, message=embeddings["error"])

        ## PROCESS RESPONSE ##
        return self._process_embedding_response(
            embeddings=embeddings,
            model_response=model_response,
            model=model,
            input=input,
            encoding=encoding,
        )

    def _transform_logprobs(
        self, hf_response: Optional[List]
    ) -> Optional[TextCompletionLogprobs]:
        """
        Transform Hugging Face logprobs to OpenAI.Completion() format
        """
        if hf_response is None:
            return None

        # Initialize an empty list for the transformed logprobs
        _logprob: TextCompletionLogprobs = TextCompletionLogprobs(
            text_offset=[],
            token_logprobs=[],
            tokens=[],
            top_logprobs=[],
        )

        # For each Hugging Face response, transform the logprobs
        for response in hf_response:
            # Extract the relevant information from the response
            response_details = response["details"]
            top_tokens = response_details.get("top_tokens", {})

            for i, token in enumerate(response_details["prefill"]):
                # Extract the text of the token
                token_text = token["text"]

                # Extract the logprob of the token
                token_logprob = token["logprob"]

                # Add the token information to the 'token_info' list
                _logprob.tokens.append(token_text)
                _logprob.token_logprobs.append(token_logprob)

                # stub this to work with llm eval harness
                top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0}  # noqa: F601
                _logprob.top_logprobs.append(top_alt_tokens)

            # For each element in the 'tokens' list, extract the relevant information
            for i, token in enumerate(response_details["tokens"]):
                # Extract the text of the token
                token_text = token["text"]

                # Extract the logprob of the token
                token_logprob = token["logprob"]

                top_alt_tokens = {}
                temp_top_logprobs = []
                if top_tokens != {}:
                    temp_top_logprobs = top_tokens[i]

                # top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
                for elem in temp_top_logprobs:
                    text = elem["text"]
                    logprob = elem["logprob"]
                    top_alt_tokens[text] = logprob

                # Add the token information to the 'token_info' list
                _logprob.tokens.append(token_text)
                _logprob.token_logprobs.append(token_logprob)
                _logprob.top_logprobs.append(top_alt_tokens)

                # Add the text offset of the token
                # This is computed as the sum of the lengths of all previous tokens
                _logprob.text_offset.append(
                    sum(len(t["text"]) for t in response_details["tokens"][:i])
                )

        return _logprob
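A minimal usage sketch of the embedding route handled above (editorial note, not part of the commit). The model id and token value are illustrative; the environment variable name follows the `HUGGINGFACE_API_KEY` lookup used elsewhere in this diff.

```python
import os

import litellm

os.environ["HUGGINGFACE_API_KEY"] = "hf_..."  # placeholder token

# Routed through the embedding() handler above; embed_url defaults to
# https://api-inference.huggingface.co/models/<model> when no api_base is given.
response = litellm.embedding(
    model="huggingface/BAAI/bge-large-en-v1.5",  # illustrative model id
    input=["hello from litellm"],
)
print(response.data[0]["embedding"][:5])
```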
590 litellm/llms/huggingface/chat/transformation.py (new file)
@@ -0,0 +1,590 @@
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||||
|
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||||
|
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
|
||||||
|
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
|
||||||
|
from litellm.secret_managers.main import get_secret_str
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||||
|
from litellm.utils import token_counter
|
||||||
|
|
||||||
|
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks, output_parser
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
|
||||||
|
LoggingClass = LiteLLMLoggingObj
|
||||||
|
else:
|
||||||
|
LoggingClass = Any
|
||||||
|
|
||||||
|
|
||||||
|
tgi_models_cache = None
|
||||||
|
conv_models_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceChatConfig(BaseConfig):
|
||||||
|
"""
|
||||||
|
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||||
|
"""
|
||||||
|
|
||||||
|
hf_task: Optional[hf_tasks] = (
|
||||||
|
None # litellm-specific param, used to know the api spec to use when calling huggingface api
|
||||||
|
)
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
decoder_input_details: Optional[bool] = None
|
||||||
|
details: Optional[bool] = True # enables returning logprobs + best of
|
||||||
|
max_new_tokens: Optional[int] = None
|
||||||
|
repetition_penalty: Optional[float] = None
|
||||||
|
return_full_text: Optional[bool] = (
|
||||||
|
False # by default don't return the input as part of the output
|
||||||
|
)
|
||||||
|
seed: Optional[int] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
top_n_tokens: Optional[int] = None
|
||||||
|
top_p: Optional[int] = None
|
||||||
|
truncate: Optional[int] = None
|
||||||
|
typical_p: Optional[float] = None
|
||||||
|
watermark: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
decoder_input_details: Optional[bool] = None,
|
||||||
|
details: Optional[bool] = None,
|
||||||
|
max_new_tokens: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: Optional[bool] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_n_tokens: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return super().get_config()
|
||||||
|
|
||||||
|
def get_special_options_params(self):
|
||||||
|
return ["use_cache", "wait_for_model"]
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str):
|
||||||
|
return [
|
||||||
|
"stream",
|
||||||
|
"temperature",
|
||||||
|
"max_tokens",
|
||||||
|
"max_completion_tokens",
|
||||||
|
"top_p",
|
||||||
|
"stop",
|
||||||
|
"n",
|
||||||
|
"echo",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: Dict,
|
||||||
|
optional_params: Dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> Dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
|
||||||
|
if param == "temperature":
|
||||||
|
if value == 0.0 or value == 0:
|
||||||
|
# hugging face exception raised when temp==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
|
||||||
|
value = 0.01
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if param == "n":
|
||||||
|
optional_params["best_of"] = value
|
||||||
|
optional_params["do_sample"] = (
|
||||||
|
True # Need to sample if you want best of for hf inference endpoints
|
||||||
|
)
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
optional_params["stop"] = value
|
||||||
|
if param == "max_tokens" or param == "max_completion_tokens":
|
||||||
|
# HF TGI raises the following exception when max_new_tokens==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
|
||||||
|
if value == 0:
|
||||||
|
value = 1
|
||||||
|
optional_params["max_new_tokens"] = value
|
||||||
|
if param == "echo":
|
||||||
|
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
|
||||||
|
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
|
||||||
|
optional_params["decoder_input_details"] = True
|
||||||
|
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def get_hf_api_key(self) -> Optional[str]:
|
||||||
|
return get_secret_str("HUGGINGFACE_API_KEY")
|
||||||
|
|
||||||
|
def read_tgi_conv_models(self):
|
||||||
|
try:
|
||||||
|
global tgi_models_cache, conv_models_cache
|
||||||
|
# Check if the cache is already populated
|
||||||
|
# so we don't keep on reading txt file if there are 1k requests
|
||||||
|
if (tgi_models_cache is not None) and (conv_models_cache is not None):
|
||||||
|
return tgi_models_cache, conv_models_cache
|
||||||
|
# If not, read the file and populate the cache
|
||||||
|
tgi_models = set()
|
||||||
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
script_directory = os.path.dirname(script_directory)
|
||||||
|
# Construct the file path relative to the script's directory
|
||||||
|
file_path = os.path.join(
|
||||||
|
script_directory,
|
||||||
|
"huggingface_llms_metadata",
|
||||||
|
"hf_text_generation_models.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
for line in file:
|
||||||
|
tgi_models.add(line.strip())
|
||||||
|
|
||||||
|
# Cache the set for future use
|
||||||
|
tgi_models_cache = tgi_models
|
||||||
|
|
||||||
|
# If not, read the file and populate the cache
|
||||||
|
file_path = os.path.join(
|
||||||
|
script_directory,
|
||||||
|
"huggingface_llms_metadata",
|
||||||
|
"hf_conversational_models.txt",
|
||||||
|
)
|
||||||
|
conv_models = set()
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
for line in file:
|
||||||
|
conv_models.add(line.strip())
|
||||||
|
# Cache the set for future use
|
||||||
|
conv_models_cache = conv_models
|
||||||
|
return tgi_models, conv_models
|
||||||
|
except Exception:
|
||||||
|
return set(), set()
|
||||||
|
|
||||||
|
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
|
||||||
|
# read text file, cast it to set
|
||||||
|
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||||
|
if model.split("/")[0] in hf_task_list:
|
||||||
|
split_model = model.split("/", 1)
|
||||||
|
return split_model[0], split_model[1] # type: ignore
|
||||||
|
tgi_models, conversational_models = self.read_tgi_conv_models()
|
||||||
|
|
||||||
|
if model in tgi_models:
|
||||||
|
return "text-generation-inference", model
|
||||||
|
elif model in conversational_models:
|
||||||
|
return "conversational", model
|
||||||
|
elif "roneneldan/TinyStories" in model:
|
||||||
|
return "text-generation", model
|
||||||
|
else:
|
||||||
|
return "text-generation-inference", model # default to tgi
|
||||||
|
|
||||||
|
def transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
task = litellm_params.get("task", None)
|
||||||
|
## VALIDATE API FORMAT
|
||||||
|
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
|
||||||
|
)
|
||||||
|
|
||||||
|
## Load Config
|
||||||
|
config = litellm.HuggingfaceConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in optional_params
|
||||||
|
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
optional_params[k] = v
|
||||||
|
|
||||||
|
### MAP INPUT PARAMS
|
||||||
|
#### HANDLE SPECIAL PARAMS
|
||||||
|
special_params = self.get_special_options_params()
|
||||||
|
special_params_dict = {}
|
||||||
|
# Create a list of keys to pop after iteration
|
||||||
|
keys_to_pop = []
|
||||||
|
|
||||||
|
for k, v in optional_params.items():
|
||||||
|
if k in special_params:
|
||||||
|
special_params_dict[k] = v
|
||||||
|
keys_to_pop.append(k)
|
||||||
|
|
||||||
|
# Pop the keys from the dictionary after iteration
|
||||||
|
for k in keys_to_pop:
|
||||||
|
optional_params.pop(k)
|
||||||
|
if task == "conversational":
|
||||||
|
inference_params = deepcopy(optional_params)
|
||||||
|
inference_params.pop("details")
|
||||||
|
inference_params.pop("return_full_text")
|
||||||
|
past_user_inputs = []
|
||||||
|
generated_responses = []
|
||||||
|
text = ""
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "user":
|
||||||
|
if text != "":
|
||||||
|
past_user_inputs.append(text)
|
||||||
|
text = convert_content_list_to_str(message)
|
||||||
|
elif message["role"] == "assistant" or message["role"] == "system":
|
||||||
|
generated_responses.append(convert_content_list_to_str(message))
|
||||||
|
data = {
|
||||||
|
"inputs": {
|
||||||
|
"text": text,
|
||||||
|
"past_user_inputs": past_user_inputs,
|
||||||
|
"generated_responses": generated_responses,
|
||||||
|
},
|
||||||
|
"parameters": inference_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif task == "text-generation-inference":
|
||||||
|
# always send "details" and "return_full_text" as params
|
||||||
|
if model in litellm.custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = litellm.custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details.get("roles", None),
|
||||||
|
initial_prompt_value=model_prompt_details.get(
|
||||||
|
"initial_prompt_value", ""
|
||||||
|
),
|
||||||
|
final_prompt_value=model_prompt_details.get(
|
||||||
|
"final_prompt_value", ""
|
||||||
|
),
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(model=model, messages=messages)
|
||||||
|
data = {
|
||||||
|
"inputs": prompt, # type: ignore
|
||||||
|
"parameters": optional_params,
|
||||||
|
"stream": ( # type: ignore
|
||||||
|
True
|
||||||
|
if "stream" in optional_params
|
||||||
|
and isinstance(optional_params["stream"], bool)
|
||||||
|
and optional_params["stream"] is True # type: ignore
|
||||||
|
else False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Non TGI and Conversational llms
|
||||||
|
# We need this branch, it removes 'details' and 'return_full_text' from params
|
||||||
|
if model in litellm.custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = litellm.custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details.get("roles", {}),
|
||||||
|
initial_prompt_value=model_prompt_details.get(
|
||||||
|
"initial_prompt_value", ""
|
||||||
|
),
|
||||||
|
final_prompt_value=model_prompt_details.get(
|
||||||
|
"final_prompt_value", ""
|
||||||
|
),
|
||||||
|
bos_token=model_prompt_details.get("bos_token", ""),
|
||||||
|
eos_token=model_prompt_details.get("eos_token", ""),
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(model=model, messages=messages)
|
||||||
|
inference_params = deepcopy(optional_params)
|
||||||
|
inference_params.pop("details")
|
||||||
|
inference_params.pop("return_full_text")
|
||||||
|
data = {
|
||||||
|
"inputs": prompt, # type: ignore
|
||||||
|
}
|
||||||
|
if task == "text-generation-inference":
|
||||||
|
data["parameters"] = inference_params
|
||||||
|
data["stream"] = ( # type: ignore
|
||||||
|
True # type: ignore
|
||||||
|
if "stream" in optional_params and optional_params["stream"] is True
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
|
### RE-ADD SPECIAL PARAMS
|
||||||
|
if len(special_params_dict.keys()) > 0:
|
||||||
|
data.update({"options": special_params_dict})
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_api_base(self, api_base: Optional[str], model: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the API base for the Huggingface API.
|
||||||
|
|
||||||
|
Do not add the chat/embedding/rerank extension here. Let the handler do this.
|
||||||
|
"""
|
||||||
|
if "https" in model:
|
||||||
|
completion_url = model
|
||||||
|
elif api_base is not None:
|
||||||
|
completion_url = api_base
|
||||||
|
elif "HF_API_BASE" in os.environ:
|
||||||
|
completion_url = os.getenv("HF_API_BASE", "")
|
||||||
|
elif "HUGGINGFACE_API_BASE" in os.environ:
|
||||||
|
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
||||||
|
else:
|
||||||
|
completion_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||||
|
|
||||||
|
return completion_url
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: Dict,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: Dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> Dict:
|
||||||
|
default_headers = {
|
||||||
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
if api_key is not None:
|
||||||
|
default_headers["Authorization"] = (
|
||||||
|
f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {**headers, **default_headers}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _transform_messages(
|
||||||
|
self,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
) -> List[AllMessageValues]:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def get_error_class(
|
||||||
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
|
) -> BaseLLMException:
|
||||||
|
return HuggingfaceError(
|
||||||
|
status_code=status_code, message=error_message, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_streamed_response_to_complete_response(
|
||||||
|
self,
|
||||||
|
response: httpx.Response,
|
||||||
|
logging_obj: LoggingClass,
|
||||||
|
model: str,
|
||||||
|
data: dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
streamed_response = CustomStreamWrapper(
|
||||||
|
completion_stream=response.iter_lines(),
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="huggingface",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
content = ""
|
||||||
|
for chunk in streamed_response:
|
||||||
|
content += chunk["choices"][0]["delta"]["content"]
|
||||||
|
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=data,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=completion_response,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
return completion_response
|
||||||
|
|
||||||
|
def convert_to_model_response_object( # noqa: PLR0915
|
||||||
|
self,
|
||||||
|
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
|
||||||
|
model_response: litellm.ModelResponse,
|
||||||
|
task: Optional[hf_tasks],
|
||||||
|
optional_params: dict,
|
||||||
|
encoding: Any,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
model: str,
|
||||||
|
):
|
||||||
|
if task is None:
|
||||||
|
task = "text-generation-inference" # default to tgi
|
||||||
|
|
||||||
|
if task == "conversational":
|
||||||
|
if len(completion_response["generated_text"]) > 0: # type: ignore
|
||||||
|
model_response.choices[0].message.content = completion_response[ # type: ignore
|
||||||
|
"generated_text"
|
||||||
|
]
|
||||||
|
elif task == "text-generation-inference":
|
||||||
|
if (
|
||||||
|
not isinstance(completion_response, list)
|
||||||
|
or not isinstance(completion_response[0], dict)
|
||||||
|
or "generated_text" not in completion_response[0]
|
||||||
|
):
|
||||||
|
raise HuggingfaceError(
|
||||||
|
status_code=422,
|
||||||
|
message=f"response is not in expected format - {completion_response}",
|
||||||
|
headers=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(completion_response[0]["generated_text"]) > 0:
|
||||||
|
model_response.choices[0].message.content = output_parser( # type: ignore
|
||||||
|
completion_response[0]["generated_text"]
|
||||||
|
)
|
||||||
|
## GETTING LOGPROBS + FINISH REASON
|
||||||
|
if (
|
||||||
|
"details" in completion_response[0]
|
||||||
|
and "tokens" in completion_response[0]["details"]
|
||||||
|
):
|
||||||
|
model_response.choices[0].finish_reason = completion_response[0][
|
||||||
|
"details"
|
||||||
|
]["finish_reason"]
|
||||||
|
sum_logprob = 0
|
||||||
|
for token in completion_response[0]["details"]["tokens"]:
|
||||||
|
if token["logprob"] is not None:
|
||||||
|
sum_logprob += token["logprob"]
|
||||||
|
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
|
||||||
|
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||||
|
if (
|
||||||
|
"details" in completion_response[0]
|
||||||
|
and "best_of_sequences" in completion_response[0]["details"]
|
||||||
|
):
|
||||||
|
choices_list = []
|
||||||
|
for idx, item in enumerate(
|
||||||
|
completion_response[0]["details"]["best_of_sequences"]
|
||||||
|
):
|
||||||
|
sum_logprob = 0
|
||||||
|
for token in item["tokens"]:
|
||||||
|
if token["logprob"] is not None:
|
||||||
|
sum_logprob += token["logprob"]
|
||||||
|
if len(item["generated_text"]) > 0:
|
||||||
|
message_obj = Message(
|
||||||
|
content=output_parser(item["generated_text"]),
|
||||||
|
logprobs=sum_logprob,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message_obj = Message(content=None)
|
||||||
|
choice_obj = Choices(
|
||||||
|
finish_reason=item["finish_reason"],
|
||||||
|
index=idx + 1,
|
||||||
|
message=message_obj,
|
||||||
|
)
|
||||||
|
choices_list.append(choice_obj)
|
||||||
|
model_response.choices.extend(choices_list)
|
||||||
|
elif task == "text-classification":
|
||||||
|
model_response.choices[0].message.content = json.dumps( # type: ignore
|
||||||
|
completion_response
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
isinstance(completion_response, list)
|
||||||
|
and len(completion_response[0]["generated_text"]) > 0
|
||||||
|
):
|
||||||
|
model_response.choices[0].message.content = output_parser( # type: ignore
|
||||||
|
completion_response[0]["generated_text"]
|
||||||
|
)
|
||||||
|
## CALCULATING USAGE
|
||||||
|
prompt_tokens = 0
|
||||||
|
try:
|
||||||
|
prompt_tokens = token_counter(model=model, messages=messages)
|
||||||
|
except Exception:
|
||||||
|
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||||
|
pass
|
||||||
|
output_text = model_response["choices"][0]["message"].get("content", "")
|
||||||
|
if output_text is not None and len(output_text) > 0:
|
||||||
|
completion_tokens = 0
|
||||||
|
try:
|
||||||
|
completion_tokens = len(
|
||||||
|
encoding.encode(
|
||||||
|
model_response["choices"][0]["message"].get("content", "")
|
||||||
|
)
|
||||||
|
) ##[TODO] use the llama2 tokenizer here
|
||||||
|
except Exception:
|
||||||
|
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
completion_tokens = 0
|
||||||
|
|
||||||
|
model_response.created = int(time.time())
|
||||||
|
model_response.model = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
model_response._hidden_params["original_response"] = completion_response
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def transform_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
raw_response: httpx.Response,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
logging_obj: LoggingClass,
|
||||||
|
request_data: Dict,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: Dict,
|
||||||
|
litellm_params: Dict,
|
||||||
|
encoding: Any,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
json_mode: Optional[bool] = None,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
||||||
|
task = litellm_params.get("task", None)
|
||||||
|
is_streamed = False
|
||||||
|
if (
|
||||||
|
raw_response.__dict__["headers"].get("Content-Type", "")
|
||||||
|
== "text/event-stream"
|
||||||
|
):
|
||||||
|
is_streamed = True
|
||||||
|
|
||||||
|
# iterate over the complete streamed response, and return the final answer
|
||||||
|
if is_streamed:
|
||||||
|
completion_response = self._convert_streamed_response_to_complete_response(
|
||||||
|
response=raw_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
model=model,
|
||||||
|
data=request_data,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=request_data,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=raw_response.text,
|
||||||
|
additional_args={"complete_input_dict": request_data},
|
||||||
|
)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = raw_response.json()
|
||||||
|
if isinstance(completion_response, dict):
|
||||||
|
completion_response = [completion_response]
|
||||||
|
except Exception:
|
||||||
|
raise HuggingfaceError(
|
||||||
|
message=f"Original Response received: {raw_response.text}",
|
||||||
|
status_code=raw_response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(completion_response, dict) and "error" in completion_response:
|
||||||
|
raise HuggingfaceError(
|
||||||
|
message=completion_response["error"], # type: ignore
|
||||||
|
status_code=raw_response.status_code,
|
||||||
|
)
|
||||||
|
return self.convert_to_model_response_object(
|
||||||
|
completion_response=completion_response,
|
||||||
|
model_response=model_response,
|
||||||
|
task=task if task is not None and task in hf_task_list else None,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
)
|
45 litellm/llms/huggingface/common_utils.py (new file)
@@ -0,0 +1,45 @@
from typing import Literal, Optional, Union

import httpx

from litellm.llms.base_llm.transformation import BaseLLMException


class HuggingfaceError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)


hf_tasks = Literal[
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]


def output_parser(generated_text: str):
    """
    Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

    Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
    """
    chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
    for token in chat_template_tokens:
        if generated_text.strip().startswith(token):
            generated_text = generated_text.replace(token, "", 1)
        if generated_text.endswith(token):
            generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
    return generated_text
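Illustrative example (not part of the commit) of what `output_parser` above does with ChatML control tokens:

```python
# Assumes output_parser is importable from litellm.llms.huggingface.common_utils,
# the file path shown in this diff.
text = "<|assistant|> Hello from LiteLLM </s>"
cleaned = output_parser(text)
print(repr(cleaned))  # ' Hello from LiteLLM ' - leading/trailing ChatML markers removed
```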
File diff suppressed because it is too large
@@ -4,59 +4,42 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
import types
|
import types
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import requests # type: ignore
|
from httpx._models import Headers
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.llms.base_llm.transformation import BaseLLMException
|
||||||
|
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
from litellm.utils import Choices, Message, ModelResponse, Usage
|
from litellm.utils import Choices, Message, ModelResponse, Usage
|
||||||
|
|
||||||
|
|
||||||
class MaritalkError(Exception):
|
class MaritalkError(BaseLLMException):
|
||||||
def __init__(self, status_code, message):
|
def __init__(
|
||||||
self.status_code = status_code
|
self,
|
||||||
self.message = message
|
status_code: int,
|
||||||
super().__init__(
|
message: str,
|
||||||
self.message
|
headers: Optional[Union[dict, Headers]] = None,
|
||||||
) # Call the base class constructor with the parameters it needs
|
):
|
||||||
|
super().__init__(status_code=status_code, message=message, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
class MaritTalkConfig:
|
class MaritalkConfig(OpenAIGPTConfig):
|
||||||
"""
|
|
||||||
The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:
|
|
||||||
|
|
||||||
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
|
|
||||||
|
|
||||||
- `model` (string): The model used for conversation. Default is 'maritalk'.
|
|
||||||
|
|
||||||
- `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
|
|
||||||
|
|
||||||
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
|
|
||||||
|
|
||||||
- `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
|
|
||||||
|
|
||||||
- `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
|
|
||||||
|
|
||||||
- `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped.
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
model: Optional[str] = None
|
|
||||||
do_sample: Optional[bool] = None
|
|
||||||
temperature: Optional[float] = None
|
|
||||||
top_p: Optional[float] = None
|
|
||||||
repetition_penalty: Optional[float] = None
|
|
||||||
stopping_tokens: Optional[List[str]] = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_tokens: Optional[int] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
model: Optional[str] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
do_sample: Optional[bool] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
top_k: Optional[int] = None,
|
||||||
stopping_tokens: Optional[List[str]] = None,
|
temperature: Optional[float] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
stream_options: Optional[dict] = None,
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
tool_choice: Optional[Union[str, dict]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
locals_ = locals()
|
locals_ = locals()
|
||||||
for key, value in locals_.items():
|
for key, value in locals_.items():
|
||||||
|
@@ -65,129 +48,27 @@ class MaritTalkConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls):
|
def get_config(cls):
|
||||||
return {
|
return super().get_config()
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if not k.startswith("__")
|
|
||||||
and not isinstance(
|
|
||||||
v,
|
|
||||||
(
|
|
||||||
types.FunctionType,
|
|
||||||
types.BuiltinFunctionType,
|
|
||||||
classmethod,
|
|
||||||
staticmethod,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
and v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List:
|
||||||
|
return [
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
"temperature",
|
||||||
|
"max_tokens",
|
||||||
|
"n",
|
||||||
|
"stop",
|
||||||
|
"stream",
|
||||||
|
"stream_options",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
]
|
||||||
|
|
||||||
def validate_environment(api_key):
|
def get_error_class(
|
||||||
headers = {
|
self, error_message: str, status_code: int, headers: Union[dict, Headers]
|
||||||
"accept": "application/json",
|
) -> BaseLLMException:
|
||||||
"content-type": "application/json",
|
return MaritalkError(
|
||||||
}
|
status_code=status_code, message=error_message, headers=headers
|
||||||
if api_key:
|
|
||||||
headers["Authorization"] = f"Key {api_key}"
|
|
||||||
return headers
|
|
||||||
|
|
||||||
|
|
||||||
def completion(
|
|
||||||
model: str,
|
|
||||||
messages: list,
|
|
||||||
api_base: str,
|
|
||||||
model_response: ModelResponse,
|
|
||||||
print_verbose: Callable,
|
|
||||||
encoding,
|
|
||||||
api_key,
|
|
||||||
logging_obj,
|
|
||||||
optional_params: dict,
|
|
||||||
litellm_params=None,
|
|
||||||
logger_fn=None,
|
|
||||||
):
|
|
||||||
headers = validate_environment(api_key)
|
|
||||||
completion_url = api_base
|
|
||||||
model = model
|
|
||||||
|
|
||||||
## Load Config
|
|
||||||
config = litellm.MaritTalkConfig.get_config()
|
|
||||||
for k, v in config.items():
|
|
||||||
if (
|
|
||||||
k not in optional_params
|
|
||||||
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
|
|
||||||
optional_params[k] = v
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"messages": messages,
|
|
||||||
**optional_params,
|
|
||||||
}
|
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.pre_call(
|
|
||||||
input=messages,
|
|
||||||
api_key=api_key,
|
|
||||||
additional_args={"complete_input_dict": data},
|
|
||||||
)
|
|
||||||
## COMPLETION CALL
|
|
||||||
response = requests.post(
|
|
||||||
completion_url,
|
|
||||||
headers=headers,
|
|
||||||
data=json.dumps(data),
|
|
||||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
|
||||||
)
|
|
||||||
if "stream" in optional_params and optional_params["stream"] is True:
|
|
||||||
return response.iter_lines()
|
|
||||||
else:
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=messages,
|
|
||||||
api_key=api_key,
|
|
||||||
original_response=response.text,
|
|
||||||
additional_args={"complete_input_dict": data},
|
|
||||||
)
|
)
|
||||||
print_verbose(f"raw model_response: {response.text}")
|
|
||||||
## RESPONSE OBJECT
|
|
||||||
completion_response = response.json()
|
|
||||||
if "error" in completion_response:
|
|
||||||
raise MaritalkError(
|
|
||||||
message=completion_response["error"],
|
|
||||||
status_code=response.status_code,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
if len(completion_response["answer"]) > 0:
|
|
||||||
model_response.choices[0].message.content = completion_response[ # type: ignore
|
|
||||||
"answer"
|
|
||||||
]
|
|
||||||
except Exception:
|
|
||||||
raise MaritalkError(
|
|
||||||
message=response.text, status_code=response.status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
## CALCULATING USAGE
|
|
||||||
prompt = "".join(m["content"] for m in messages)
|
|
||||||
prompt_tokens = len(encoding.encode(prompt))
|
|
||||||
completion_tokens = len(
|
|
||||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
model_response.created = int(time.time())
|
|
||||||
model_response.model = model
|
|
||||||
usage = Usage(
|
|
||||||
prompt_tokens=prompt_tokens,
|
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
|
||||||
)
|
|
||||||
setattr(model_response, "usage", usage)
|
|
||||||
return model_response
|
|
||||||
|
|
||||||
|
|
||||||
def embedding(
|
|
||||||
model: str,
|
|
||||||
input: list,
|
|
||||||
api_key: Optional[str],
|
|
||||||
logging_obj: Any,
|
|
||||||
model_response=None,
|
|
||||||
encoding=None,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
|
@@ -9,11 +9,16 @@ Docs - https://docs.mistral.ai/api/
|
||||||
import types
|
import types
|
||||||
from typing import List, Literal, Optional, Tuple, Union
|
from typing import List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
|
from litellm.llms.prompt_templates.common_utils import (
|
||||||
|
handle_messages_with_content_list_to_str_conversion,
|
||||||
|
strip_none_values_from_message,
|
||||||
|
)
|
||||||
from litellm.secret_managers.main import get_secret_str
|
from litellm.secret_managers.main import get_secret_str
|
||||||
from litellm.types.llms.openai import AllMessageValues
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
|
||||||
|
|
||||||
class MistralConfig:
|
class MistralConfig(OpenAIGPTConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://docs.mistral.ai/api/
|
Reference: https://docs.mistral.ai/api/
|
||||||
|
|
||||||
|
@ -67,23 +72,9 @@ class MistralConfig:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls):
|
def get_config(cls):
|
||||||
return {
|
return super().get_config()
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if not k.startswith("__")
|
|
||||||
and not isinstance(
|
|
||||||
v,
|
|
||||||
(
|
|
||||||
types.FunctionType,
|
|
||||||
types.BuiltinFunctionType,
|
|
||||||
classmethod,
|
|
||||||
staticmethod,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
and v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_supported_openai_params(self):
|
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||||
return [
|
return [
|
||||||
"stream",
|
"stream",
|
||||||
"temperature",
|
"temperature",
|
||||||
|
@ -104,7 +95,13 @@ class MistralConfig:
|
||||||
else: # openai 'tool_choice' object param not supported by Mistral API
|
else: # openai 'tool_choice' object param not supported by Mistral API
|
||||||
return "any"
|
return "any"
|
||||||
|
|
||||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
for param, value in non_default_params.items():
|
for param, value in non_default_params.items():
|
||||||
if param == "max_tokens":
|
if param == "max_tokens":
|
||||||
optional_params["max_tokens"] = value
|
optional_params["max_tokens"] = value
|
||||||
|
@ -150,8 +147,9 @@ class MistralConfig:
|
||||||
)
|
)
|
||||||
return api_base, dynamic_api_key
|
return api_base, dynamic_api_key
|
||||||
|
|
||||||
@classmethod
|
def _transform_messages(
|
||||||
def _transform_messages(cls, messages: List[AllMessageValues]):
|
self, messages: List[AllMessageValues]
|
||||||
|
) -> List[AllMessageValues]:
|
||||||
"""
|
"""
|
||||||
- handles scenario where content is list and not string
|
- handles scenario where content is list and not string
|
||||||
- content list is just text, and no images
|
- content list is just text, and no images
|
||||||
|
@ -160,48 +158,36 @@ class MistralConfig:
|
||||||
|
|
||||||
Motivation: mistral api doesn't support content as a list
|
Motivation: mistral api doesn't support content as a list
|
||||||
"""
|
"""
|
||||||
new_messages = []
|
## 1. If 'image_url' in content, then return as is
|
||||||
for m in messages:
|
for m in messages:
|
||||||
special_keys = ["role", "content", "tool_calls", "function_call"]
|
_content_block = m.get("content")
|
||||||
extra_args = {}
|
if _content_block and isinstance(_content_block, list):
|
||||||
if isinstance(m, dict):
|
for c in _content_block:
|
||||||
for k, v in m.items():
|
if c.get("type") == "image_url":
|
||||||
if k not in special_keys:
|
|
||||||
extra_args[k] = v
|
|
||||||
texts = ""
|
|
||||||
_content = m.get("content")
|
|
||||||
if _content is not None and isinstance(_content, list):
|
|
||||||
for c in _content:
|
|
||||||
_text: Optional[str] = c.get("text")
|
|
||||||
if c["type"] == "image_url":
|
|
||||||
return messages
|
return messages
|
||||||
elif c["type"] == "text" and isinstance(_text, str):
|
|
||||||
texts += _text
|
|
||||||
elif _content is not None and isinstance(_content, str):
|
|
||||||
texts = _content
|
|
||||||
|
|
||||||
new_m = {"role": m["role"], "content": texts, **extra_args}
|
## 2. If content is list, then convert to string
|
||||||
|
messages = handle_messages_with_content_list_to_str_conversion(messages)
|
||||||
|
|
||||||
if m.get("tool_calls"):
|
## 3. Handle name in message
|
||||||
new_m["tool_calls"] = m.get("tool_calls")
|
new_messages: List[AllMessageValues] = []
|
||||||
|
for m in messages:
|
||||||
|
m = MistralConfig._handle_name_in_message(m)
|
||||||
|
m = strip_none_values_from_message(m) # prevents 'extra_forbidden' error
|
||||||
|
new_messages.append(m)
|
||||||
|
|
||||||
new_m = cls._handle_name_in_message(new_m)
|
|
||||||
|
|
||||||
new_messages.append(new_m)
|
|
||||||
return new_messages
|
return new_messages
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _handle_name_in_message(cls, message: dict) -> dict:
|
def _handle_name_in_message(cls, message: AllMessageValues) -> AllMessageValues:
|
||||||
"""
|
"""
|
||||||
Mistral API only supports `name` in tool messages
|
Mistral API only supports `name` in tool messages
|
||||||
|
|
||||||
If role == tool, then we keep `name`
|
If role == tool, then we keep `name`
|
||||||
Otherwise, we drop `name`
|
Otherwise, we drop `name`
|
||||||
"""
|
"""
|
||||||
if message.get("name") is not None:
|
_name = message.get("name") # type: ignore
|
||||||
if message["role"] == "tool":
|
if _name is not None and message["role"] != "tool":
|
||||||
message["name"] = message.get("name")
|
message.pop("name", None) # type: ignore
|
||||||
else:
|
|
||||||
message.pop("name", None)
|
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
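Editorial sketch (not part of the commit) of the `_handle_name_in_message` behaviour changed above: `name` is kept only on tool-role messages, since the Mistral API only accepts it there.

```python
# Hypothetical messages; shapes follow the OpenAI chat format litellm uses.
tool_msg = {"role": "tool", "name": "get_weather", "content": "72F"}
user_msg = {"role": "user", "name": "alice", "content": "hi"}

print(MistralConfig._handle_name_in_message(tool_msg))  # keeps "name"
print(MistralConfig._handle_name_in_message(user_msg))  # "name" is popped
```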
140 litellm/llms/nlp_cloud/chat/handler.py (new file)
@@ -0,0 +1,140 @@
import json
import os
import time
import types
from enum import Enum
from typing import Any, Callable, List, Optional, Union

import httpx

import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse, Usage

from ..common_utils import NLPCloudError
from .transformation import NLPCloudConfig

nlp_config = NLPCloudConfig()


def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    litellm_params: dict,
    logger_fn=None,
    default_max_tokens_to_sample=None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    headers={},
):
    headers = nlp_config.validate_environment(
        api_key=api_key,
        headers=headers,
        model=model,
        messages=messages,
        optional_params=optional_params,
    )

    ## Load Config
    config = litellm.NLPCloudConfig.get_config()
    for k, v in config.items():
        if (
            k not in optional_params
        ):  # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
            optional_params[k] = v

    completion_url_fragment_1 = api_base
    completion_url_fragment_2 = "/generation"
    model = model

    completion_url = completion_url_fragment_1 + model + completion_url_fragment_2
    data = nlp_config.transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )

    ## LOGGING
    logging_obj.pre_call(
        input=None,
        api_key=api_key,
        additional_args={
            "complete_input_dict": data,
            "headers": headers,
            "api_base": completion_url,
        },
    )
    ## COMPLETION CALL
    if client is None or not isinstance(client, HTTPHandler):
        client = _get_httpx_client()

    response = client.post(
        completion_url,
        headers=headers,
        data=json.dumps(data),
        stream=optional_params["stream"] if "stream" in optional_params else False,
    )
    if "stream" in optional_params and optional_params["stream"] is True:
        return clean_and_iterate_chunks(response)
    else:
        return nlp_config.transform_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )


# def clean_and_iterate_chunks(response):
#     def process_chunk(chunk):
#         print(f"received chunk: {chunk}")
#         cleaned_chunk = chunk.decode("utf-8")
#         # Perform further processing based on your needs
#         return cleaned_chunk


#     for line in response.iter_lines():
#         if line:
#             yield process_chunk(line)
def clean_and_iterate_chunks(response):
    buffer = b""

    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:
            break

        buffer += chunk
        while b"\x00" in buffer:
            buffer = buffer.replace(b"\x00", b"")
            yield buffer.decode("utf-8")
            buffer = b""

    # No more data expected, yield any remaining data in the buffer
    if buffer:
        yield buffer.decode("utf-8")


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
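Editorial sketch (not part of the commit) of how `clean_and_iterate_chunks` above strips the NUL padding between streamed NLP Cloud chunks; `FakeResponse` is a stand-in for the HTTP response object, purely for illustration.

```python
class FakeResponse:
    def iter_content(self, chunk_size=1024):
        yield b"Hello\x00"  # chunk padded with a NUL byte
        yield b" world"

for text in clean_and_iterate_chunks(FakeResponse()):
    print(repr(text))  # 'Hello' then ' world'
```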
@@ -1,26 +1,25 @@
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
import types
|
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||||
from enum import Enum
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import requests # type: ignore
|
import httpx
|
||||||
|
|
||||||
import litellm
|
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||||
|
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
|
|
||||||
|
from ..common_utils import NLPCloudError
|
||||||
|
|
||||||
class NLPCloudError(Exception):
|
if TYPE_CHECKING:
|
||||||
def __init__(self, status_code, message):
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
self.status_code = status_code
|
|
||||||
self.message = message
|
LoggingClass = LiteLLMLoggingObj
|
||||||
super().__init__(
|
else:
|
||||||
self.message
|
LoggingClass = Any
|
||||||
) # Call the base class constructor with the parameters it needs
|
|
||||||
|
|
||||||
|
|
||||||
class NLPCloudConfig:
|
class NLPCloudConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://docs.nlpcloud.com/#generation
|
Reference: https://docs.nlpcloud.com/#generation
|
||||||
|
|
||||||
|
@ -84,106 +83,119 @@ class NLPCloudConfig:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls):
|
def get_config(cls):
|
||||||
return {
|
return super().get_config()
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
def validate_environment(
|
||||||
if not k.startswith("__")
|
self,
|
||||||
and not isinstance(
|
headers: dict,
|
||||||
v,
|
model: str,
|
||||||
(
|
messages: List[AllMessageValues],
|
||||||
types.FunctionType,
|
optional_params: dict,
|
||||||
types.BuiltinFunctionType,
|
api_key: Optional[str] = None,
|
||||||
classmethod,
|
) -> dict:
|
||||||
staticmethod,
|
headers = {
|
||||||
),
|
"accept": "application/json",
|
||||||
)
|
"content-type": "application/json",
|
||||||
and v is not None
|
}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Token {api_key}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"stream",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
"n",
|
||||||
|
"stop",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["max_length"] = value
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if param == "presence_penalty":
|
||||||
|
optional_params["presence_penalty"] = value
|
||||||
|
if param == "frequency_penalty":
|
||||||
|
optional_params["frequency_penalty"] = value
|
||||||
|
if param == "n":
|
||||||
|
optional_params["num_return_sequences"] = value
|
||||||
|
if param == "stop":
|
||||||
|
optional_params["stop_sequences"] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def get_error_class(
|
||||||
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
|
) -> BaseLLMException:
|
||||||
|
return NLPCloudError(
|
||||||
|
status_code=status_code, message=error_message, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
def transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
text = " ".join(convert_content_list_to_str(message) for message in messages)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"text": text,
|
||||||
|
**optional_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
-def validate_environment(api_key):
-    headers = {
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
-    if api_key:
-        headers["Authorization"] = f"Token {api_key}"
-    return headers
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-    default_max_tokens_to_sample=None,
-):
-    headers = validate_environment(api_key)
-
-    ## Load Config
-    config = litellm.NLPCloudConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    completion_url_fragment_1 = api_base
-    completion_url_fragment_2 = "/generation"
-    model = model
-    text = " ".join(message["content"] for message in messages)
-
-    data = {
-        "text": text,
-        **optional_params,
-    }
-
-    completion_url = completion_url_fragment_1 + model + completion_url_fragment_2
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=text,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": completion_url,
-        },
-    )
-    ## COMPLETION CALL
-    response = requests.post(
-        completion_url,
-        headers=headers,
-        data=json.dumps(data),
-        stream=optional_params["stream"] if "stream" in optional_params else False,
-    )
-    if "stream" in optional_params and optional_params["stream"] is True:
-        return clean_and_iterate_chunks(response)
-    else:
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LoggingClass,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
-            input=text,
+            input=None,
             api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
         )
-        print_verbose(f"raw model_response: {response.text}")
         ## RESPONSE OBJECT
         try:
-            completion_response = response.json()
+            completion_response = raw_response.json()
         except Exception:
-            raise NLPCloudError(message=response.text, status_code=response.status_code)
+            raise NLPCloudError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
         if "error" in completion_response:
             raise NLPCloudError(
                 message=completion_response["error"],
-                status_code=response.status_code,
+                status_code=raw_response.status_code,
             )
         else:
             try:
@@ -194,7 +206,7 @@ def completion(
             except Exception:
                 raise NLPCloudError(
                     message=json.dumps(completion_response),
-                    status_code=response.status_code,
+                    status_code=raw_response.status_code,
                 )

        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
@@ -210,37 +222,3 @@ def completion(
         )
         setattr(model_response, "usage", usage)
         return model_response
-
-
-# def clean_and_iterate_chunks(response):
-#     def process_chunk(chunk):
-#         print(f"received chunk: {chunk}")
-#         cleaned_chunk = chunk.decode("utf-8")
-#         # Perform further processing based on your needs
-#         return cleaned_chunk
-
-
-#     for line in response.iter_lines():
-#         if line:
-#             yield process_chunk(line)
-def clean_and_iterate_chunks(response):
-    buffer = b""
-
-    for chunk in response.iter_content(chunk_size=1024):
-        if not chunk:
-            break
-
-        buffer += chunk
-        while b"\x00" in buffer:
-            buffer = buffer.replace(b"\x00", b"")
-            yield buffer.decode("utf-8")
-            buffer = b""
-
-    # No more data expected, yield any remaining data in the buffer
-    if buffer:
-        yield buffer.decode("utf-8")
-
-
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
litellm/llms/nlp_cloud/common_utils.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class NLPCloudError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ):
+        super().__init__(status_code=status_code, message=message, headers=headers)
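Illustrative, standalone sketch (not part of this diff) of the OpenAI-to-NLP-Cloud parameter mapping shown above; the real logic lives in the config's map_openai_params method, and the names below are hypothetical.

def map_openai_params_sketch(non_default_params: dict) -> dict:
    # rename the OpenAI-style keys to their NLP Cloud equivalents
    renames = {
        "max_tokens": "max_length",
        "n": "num_return_sequences",
        "stop": "stop_sequences",
    }
    # keys that pass through unchanged
    passthrough = {"stream", "temperature", "top_p", "presence_penalty", "frequency_penalty"}
    optional_params: dict = {}
    for param, value in non_default_params.items():
        if param in renames:
            optional_params[renames[param]] = value
        elif param in passthrough:
            optional_params[param] = value
    return optional_params

# e.g. {"max_tokens": 100, "stop": ["\n"]} -> {"max_length": 100, "stop_sequences": ["\n"]}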
@@ -11,8 +11,10 @@ API calling is done using the OpenAI SDK with an api_base
 import types
 from typing import Optional, Union

+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+

-class NvidiaNimConfig:
+class NvidiaNimConfig(OpenAIGPTConfig):
     """
     Reference: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer

@@ -42,21 +44,7 @@ class NvidiaNimConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

     def get_supported_openai_params(self, model: str) -> list:
         """
@@ -132,7 +120,11 @@ class NvidiaNimConfig:
         ]

     def map_openai_params(
-        self, model: str, non_default_params: dict, optional_params: dict
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
     ) -> dict:
         supported_openai_params = self.get_supported_openai_params(model=model)
         for param, value in non_default_params.items():
@@ -242,6 +242,7 @@ class OllamaConfig(BaseConfig):
         request_data: dict,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         encoding: str,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
@@ -14,6 +14,7 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
 from litellm.types.utils import StreamingChoices
@@ -30,7 +31,7 @@ class OllamaError(Exception):
         )  # Call the base class constructor with the parameters it needs


-class OllamaChatConfig:
+class OllamaChatConfig(OpenAIGPTConfig):
     """
     Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters

@@ -81,15 +82,10 @@ class OllamaChatConfig:
     num_thread: Optional[int] = None
     repeat_last_n: Optional[int] = None
     repeat_penalty: Optional[float] = None
-    temperature: Optional[float] = None
     seed: Optional[int] = None
-    stop: Optional[list] = (
-        None  # stop is a list based on this - https://github.com/ollama/ollama/pull/442
-    )
     tfs_z: Optional[float] = None
     num_predict: Optional[int] = None
     top_k: Optional[int] = None
-    top_p: Optional[float] = None
     system: Optional[str] = None
     template: Optional[str] = None

@@ -120,26 +116,9 @@ class OllamaChatConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and k != "function_name"  # special param for function calling
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

-    def get_supported_openai_params(
-        self,
-    ):
+    def get_supported_openai_params(self, model: str):
         return [
             "max_tokens",
             "max_completion_tokens",
@@ -156,8 +135,12 @@ class OllamaChatConfig:
         ]

     def map_openai_params(
-        self, model: str, non_default_params: dict, optional_params: dict
-    ):
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["num_predict"] = value
@@ -6,28 +6,14 @@ from typing import Any, Callable, Optional

 import requests  # type: ignore

+from litellm.llms.custom_httpx.http_handler import HTTPHandler, _get_httpx_client
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage

-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import OobaboogaError
+from .transformation import OobaboogaConfig
+
+oobabooga_config = OobaboogaConfig()

-class OobaboogaError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-def validate_environment(api_key):
-    headers = {
-        "accept": "application/json",
-        "content-type": "application/json",
-    }
-    if api_key:
-        headers["Authorization"] = f"Token {api_key}"
-    return headers
-
-
 def completion(
@@ -40,12 +26,18 @@ def completion(
     api_key,
     logging_obj,
     optional_params: dict,
+    litellm_params: dict,
     custom_prompt_dict={},
-    litellm_params=None,
     logger_fn=None,
     default_max_tokens_to_sample=None,
 ):
-    headers = validate_environment(api_key)
+    headers = oobabooga_config.validate_environment(
+        api_key=api_key,
+        headers={},
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+    )
     if "https" in model:
         completion_url = model
     elif api_base:
@@ -58,10 +50,13 @@ def completion(
         model = model

     completion_url = completion_url + "/v1/chat/completions"
-    data = {
-        "messages": messages,
-        **optional_params,
-    }
+    data = oobabooga_config.transform_request(
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        headers=headers,
+    )
     ## LOGGING

     logging_obj.pre_call(
@@ -70,8 +65,8 @@ def completion(
         additional_args={"complete_input_dict": data},
     )
     ## COMPLETION CALL
-    response = requests.post(
+    client = _get_httpx_client()
+    response = client.post(
         completion_url,
         headers=headers,
         data=json.dumps(data),
@@ -80,44 +75,18 @@ def completion(
     if "stream" in optional_params and optional_params["stream"] is True:
         return response.iter_lines()
     else:
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
+        return oobabooga_config.transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
             api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
         )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except Exception:
-            raise OobaboogaError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise OobaboogaError(
-                message=completion_response["error"],
-                status_code=response.status_code,
-            )
-        else:
-            try:
-                model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"]  # type: ignore
-            except Exception:
-                raise OobaboogaError(
-                    message=json.dumps(completion_response),
-                    status_code=response.status_code,
-                )
-
-        model_response.created = int(time.time())
-        model_response.model = model
-        usage = Usage(
-            prompt_tokens=completion_response["usage"]["prompt_tokens"],
-            completion_tokens=completion_response["usage"]["completion_tokens"],
-            total_tokens=completion_response["usage"]["total_tokens"],
-        )
-        setattr(model_response, "usage", usage)
-        return model_response


 def embedding(
@@ -127,7 +96,7 @@ def embedding(
     api_key: Optional[str],
     api_base: Optional[str],
     logging_obj: Any,
-    optional_params=None,
+    optional_params: dict,
     encoding=None,
 ):
     # Create completion URL
@@ -153,7 +122,13 @@ def embedding(
     )

     # Send POST request
-    headers = validate_environment(api_key)
+    headers = oobabooga_config.validate_environment(
+        api_key=api_key,
+        headers={},
+        model=model,
+        messages=[],
+        optional_params=optional_params,
+    )
     response = requests.post(embeddings_url, headers=headers, json=data)
     if not response.ok:
         raise OobaboogaError(message=response.text, status_code=response.status_code)
litellm/llms/oobabooga/chat/transformation.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+import json
+import time
+import types
+from typing import TYPE_CHECKING, Any, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.utils import token_counter
+
+from ..common_utils import OobaboogaError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+    LoggingClass = LiteLLMLoggingObj
+else:
+    LoggingClass = Any
+
+
+class OobaboogaConfig(OpenAIGPTConfig):
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        return messages
+
+    def get_error_class(
+        self,
+        error_message: str,
+        status_code: int,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ) -> BaseLLMException:
+        return OobaboogaError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LoggingClass,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = raw_response.json()
+        except Exception:
+            raise OobaboogaError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+        if "error" in completion_response:
+            raise OobaboogaError(
+                message=completion_response["error"],
+                status_code=raw_response.status_code,
+            )
+        else:
+            try:
+                model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"]  # type: ignore
+            except Exception as e:
+                raise OobaboogaError(
+                    message=str(e),
+                    status_code=raw_response.status_code,
+                )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=completion_response["usage"]["prompt_tokens"],
+            completion_tokens=completion_response["usage"]["completion_tokens"],
+            total_tokens=completion_response["usage"]["total_tokens"],
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+        if api_key is not None:
+            headers["Authorization"] = f"Token {api_key}"
+        return headers
litellm/llms/oobabooga/common_utils.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class OobaboogaError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ):
+        super().__init__(status_code=status_code, message=message, headers=headers)
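A standalone sketch of the provider error-class pattern these new files follow (not part of the diff; the base class here is a hypothetical stand-in for BaseLLMException, which carries status_code, message, and headers):

from typing import Optional, Union

class BaseLLMExceptionSketch(Exception):
    # hypothetical stand-in for litellm's BaseLLMException
    def __init__(self, status_code: int, message: str, headers: Optional[Union[dict, object]] = None):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        super().__init__(message)

class ProviderErrorSketch(BaseLLMExceptionSketch):
    """Mirrors NLPCloudError / OobaboogaError: it only forwards to the shared base."""
    pass

err = ProviderErrorSketch(status_code=500, message="upstream failure")
assert err.status_code == 500 and str(err) == "upstream failure"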
@@ -197,7 +197,8 @@ class OpenAIGPTConfig(BaseConfig):
         request_data: dict,
         messages: List[AllMessageValues],
         optional_params: dict,
-        encoding: str,
+        litellm_params: dict,
+        encoding: Any,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
@@ -1,63 +0,0 @@
 """
-Handler file for calls to OpenAI's o1 family of models
+LLM Calling done in `openai/openai.py`

-Written separately to handle faking streaming for o1 models.
 """
-
-import asyncio
-from typing import Any, Callable, List, Optional, Union
-
-from httpx._config import Timeout
-
-from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
-from litellm.llms.openai.openai import OpenAIChatCompletion
-from litellm.types.utils import ModelResponse
-from litellm.utils import CustomStreamWrapper
-
-
-class OpenAIO1ChatCompletion(OpenAIChatCompletion):
-
-    def completion(
-        self,
-        model_response: ModelResponse,
-        timeout: Union[float, Timeout],
-        optional_params: dict,
-        logging_obj: Any,
-        model: Optional[str] = None,
-        messages: Optional[list] = None,
-        print_verbose: Optional[Callable[..., Any]] = None,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        acompletion: bool = False,
-        litellm_params=None,
-        logger_fn=None,
-        headers: Optional[dict] = None,
-        custom_prompt_dict: dict = {},
-        client=None,
-        organization: Optional[str] = None,
-        custom_llm_provider: Optional[str] = None,
-        drop_params: Optional[bool] = None,
-    ):
-        # stream: Optional[bool] = optional_params.pop("stream", False)
-        response = super().completion(
-            model_response,
-            timeout,
-            optional_params,
-            logging_obj,
-            model,
-            messages,
-            print_verbose,
-            api_key,
-            api_base,
-            acompletion,
-            litellm_params,
-            logger_fn,
-            headers,
-            custom_prompt_dict,
-            client,
-            organization,
-            custom_llm_provider,
-            drop_params,
-        )
-
-        return response
@@ -3,7 +3,7 @@ Common helpers / utils across al OpenAI endpoints
 """

 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import httpx
 import openai
@@ -18,7 +18,7 @@ class OpenAIError(BaseLLMException):
         message: str,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
-        headers: Optional[httpx.Headers] = None,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
     ):
         self.status_code = status_code
         self.message = message
@@ -4,7 +4,17 @@ import os
 import time
 import traceback
 import types
-from typing import Any, Callable, Coroutine, Iterable, Literal, Optional, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Union,
+    cast,
+)

 import httpx
 import openai
@@ -18,6 +28,7 @@ import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import ProviderField
@@ -35,6 +46,7 @@ from litellm.utils import (
 from ...types.llms.openai import *
 from ..base import BaseLLM
 from ..prompt_templates.factory import custom_prompt, prompt_factory
+from .chat.gpt_transformation import OpenAIGPTConfig
 from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@@ -81,135 +93,7 @@ class MistralEmbeddingConfig:
         return optional_params


-class DeepInfraConfig:
+class OpenAIConfig(BaseConfig):
-    """
-    Reference: https://deepinfra.com/docs/advanced/openai_api
-
-    The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
-    """
-
-    frequency_penalty: Optional[int] = None
-    function_call: Optional[Union[str, dict]] = None
-    functions: Optional[list] = None
-    logit_bias: Optional[dict] = None
-    max_tokens: Optional[int] = None
-    n: Optional[int] = None
-    presence_penalty: Optional[int] = None
-    stop: Optional[Union[str, list]] = None
-    temperature: Optional[int] = None
-    top_p: Optional[int] = None
-    response_format: Optional[dict] = None
-    tools: Optional[list] = None
-    tool_choice: Optional[Union[str, dict]] = None
-
-    def __init__(
-        self,
-        frequency_penalty: Optional[int] = None,
-        function_call: Optional[Union[str, dict]] = None,
-        functions: Optional[list] = None,
-        logit_bias: Optional[dict] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[int] = None,
-        stop: Optional[Union[str, list]] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
-        response_format: Optional[dict] = None,
-        tools: Optional[list] = None,
-        tool_choice: Optional[Union[str, dict]] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "frequency_penalty",
-            "function_call",
-            "functions",
-            "logit_bias",
-            "max_tokens",
-            "max_completion_tokens",
-            "n",
-            "presence_penalty",
-            "stop",
-            "temperature",
-            "top_p",
-            "response_format",
-            "tools",
-            "tool_choice",
-        ]
-
-    def map_openai_params(
-        self,
-        non_default_params: dict,
-        optional_params: dict,
-        model: str,
-        drop_params: bool,
-    ) -> dict:
-        supported_openai_params = self.get_supported_openai_params()
-        for param, value in non_default_params.items():
-            if (
-                param == "temperature"
-                and value == 0
-                and model == "mistralai/Mistral-7B-Instruct-v0.1"
-            ):  # this model does no support temperature == 0
-                value = 0.0001  # close to 0
-            if param == "tool_choice":
-                if (
-                    value != "auto" and value != "none"
-                ):  # https://deepinfra.com/docs/advanced/function_calling
-                    ## UNSUPPORTED TOOL CHOICE VALUE
-                    if litellm.drop_params is True or drop_params is True:
-                        value = None
-                    else:
-                        raise litellm.utils.UnsupportedParamsError(
-                            message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
-                                value
-                            ),
-                            status_code=400,
-                        )
-            elif param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            elif param in supported_openai_params:
-                if value is not None:
-                    optional_params[param] = value
-        return optional_params
-
-    def _get_openai_compatible_provider_info(
-        self, api_base: Optional[str], api_key: Optional[str]
-    ) -> Tuple[Optional[str], Optional[str]]:
-        # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
-        api_base = (
-            api_base
-            or get_secret_str("DEEPINFRA_API_BASE")
-            or "https://api.deepinfra.com/v1/openai"
-        )
-        dynamic_api_key = api_key or get_secret_str("DEEPINFRA_API_KEY")
-        return api_base, dynamic_api_key
-
-
-class OpenAIConfig:
     """
     Reference: https://platform.openai.com/docs/api-reference/chat/create

@@ -273,25 +157,12 @@ class OpenAIConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

     def get_supported_openai_params(self, model: str) -> list:
         """
-        This function returns the list of supported openai parameters for a given OpenAI Model
+        This function returns the list
+        of supported openai parameters for a given OpenAI Model

         - If O1 model, returns O1 supported params
         - If gpt-audio model, returns gpt-audio supported params
@@ -319,6 +190,11 @@ class OpenAIConfig:
             optional_params[param] = value
         return optional_params

+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        return messages
+
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -349,6 +225,55 @@ class OpenAIConfig:
             drop_params=drop_params,
         )

+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return OpenAIError(
+            status_code=status_code,
+            message=error_message,
+            headers=headers,
+        )
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {"model": model, "messages": messages, **optional_params}
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "OpenAI handler does this transformation as it uses the OpenAI SDK."
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        raise NotImplementedError(
+            "OpenAI handler does this validation as it uses the OpenAI SDK."
+        )
+
+
 class OpenAIChatCompletion(BaseLLM):

@@ -483,6 +408,7 @@ class OpenAIChatCompletion(BaseLLM):
         model_response: ModelResponse,
         timeout: Union[float, httpx.Timeout],
         optional_params: dict,
+        litellm_params: dict,
         logging_obj: Any,
         model: Optional[str] = None,
         messages: Optional[list] = None,
@@ -490,7 +416,6 @@ class OpenAIChatCompletion(BaseLLM):
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         acompletion: bool = False,
-        litellm_params=None,
         logger_fn=None,
         headers: Optional[dict] = None,
         custom_prompt_dict: dict = {},
@@ -516,31 +441,26 @@ class OpenAIChatCompletion(BaseLLM):

         if custom_llm_provider is not None and custom_llm_provider != "openai":
             model_response.model = f"{custom_llm_provider}/{model}"
-            # process all OpenAI compatible provider logic here
-            if custom_llm_provider == "mistral":
-                # check if message content passed in as list, and not string
-                messages = prompt_factory(  # type: ignore
-                    model=model,
-                    messages=messages,
-                    custom_llm_provider=custom_llm_provider,
-                )
-            if custom_llm_provider == "perplexity" and messages is not None:
-                # check if messages.name is passed + supported, if not supported remove
-                messages = prompt_factory(  # type: ignore
-                    model=model,
-                    messages=messages,
-                    custom_llm_provider=custom_llm_provider,
-                )
         if messages is not None and custom_llm_provider is not None:
             provider_config = ProviderConfigManager.get_provider_chat_config(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
-            messages = provider_config._transform_messages(messages)
+            if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
+                provider_config, OpenAIConfig
+            ):
+                messages = provider_config._transform_messages(messages)

         for _ in range(
             2
         ):  # if call fails due to alternating messages, retry with reformatted message
-            data = {"model": model, "messages": messages, **optional_params}
+            data = OpenAIConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers or {},
+            )

             try:
                 max_retries = data.pop("max_retries", 2)
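Minimal, standalone sketch (not part of this diff) of the request transform the handler now delegates to: the chat payload is just the model, the messages, and any mapped optional params merged into one dict. The helper name below is illustrative.

def transform_request_sketch(model: str, messages: list, optional_params: dict) -> dict:
    # merge everything into the body the OpenAI SDK will send
    return {"model": model, "messages": messages, **optional_params}

data = transform_request_sketch(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    optional_params={"temperature": 0.2, "max_retries": 2},
)
# the handler later pops client-only keys, e.g. data.pop("max_retries", 2)
assert data["temperature"] == 0.2 and data.pop("max_retries", 2) == 2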
@@ -2430,7 +2350,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         """
         Here's an example:
         ```
-        from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData
+        from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData

         # create thread
         message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
@@ -26,6 +26,8 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.llms.databricks.streaming_utils import ModelResponseIterator
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.openai.openai import OpenAIConfig
 from litellm.types.utils import CustomStreamingDecoder, ModelResponse
 from litellm.utils import (
     Choices,
@@ -205,6 +207,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as e:
+            print(f"e.response.text: {e.response.text}")
             raise OpenAILikeError(
                 status_code=e.response.status_code,
                 message=e.response.text,
@@ -212,6 +215,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
         except httpx.TimeoutException:
             raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
         except Exception as e:
+            print(f"e: {e}")
             raise OpenAILikeError(status_code=500, message=str(e))

         return OpenAILikeChatConfig._transform_response(
@@ -280,7 +284,10 @@ class OpenAILikeChatHandler(OpenAILikeBase):
             provider_config = ProviderConfigManager.get_provider_chat_config(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
-            messages = provider_config._transform_messages(messages)
+            if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
+                provider_config, OpenAIConfig
+            ):
+                messages = provider_config._transform_messages(messages)

         data = {
             "model": model,
@@ -75,6 +75,7 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
         custom_llm_provider: str,
         base_model: Optional[str],
     ) -> ModelResponse:
+        print(f"response: {response}")
         response_json = response.json()
         logging_obj.post_call(
             input=messages,
@@ -99,3 +100,25 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
         if base_model is not None:
             returned_response._hidden_params["model"] = base_model
         return returned_response
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+        replace_max_completion_tokens_with_max_tokens: bool = True,
+    ) -> dict:
+        mapped_params = super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
+        if (
+            "max_completion_tokens" in non_default_params
+            and replace_max_completion_tokens_with_max_tokens
+        ):
+            mapped_params["max_tokens"] = non_default_params[
+                "max_completion_tokens"
+            ]  # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens'
+            mapped_params.pop("max_completion_tokens", None)
+
+        return mapped_params
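Standalone sketch (not part of this diff) of the max_completion_tokens fallback added above, since most OpenAI-compatible providers only accept max_tokens; the function name is illustrative.

def replace_max_completion_tokens(non_default_params: dict, mapped_params: dict) -> dict:
    # swap the newer OpenAI param name for the widely supported one
    if "max_completion_tokens" in non_default_params:
        mapped_params["max_tokens"] = non_default_params["max_completion_tokens"]
        mapped_params.pop("max_completion_tokens", None)
    return mapped_params

print(replace_max_completion_tokens(
    {"max_completion_tokens": 256},
    {"max_completion_tokens": 256, "temperature": 0.1},
))
# -> {'temperature': 0.1, 'max_tokens': 256}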
@@ -1,41 +0,0 @@
-from typing import List, Dict
-import types
-
-
-class OpenrouterConfig:
-    """
-    Reference: https://openrouter.ai/docs#format
-
-    """
-
-    # OpenRouter-only parameters
-    extra_body: Dict[str, List[str]] = {"transforms": []}  # default transforms to []
-
-    def __init__(
-        self,
-        transforms: List[str] = [],
-        models: List[str] = [],
-        route: str = "",
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
litellm/llms/openrouter/chat/transformation.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+"""
+Support for OpenAI's `/v1/chat/completions` endpoint.
+
+Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
+
+Docs: https://openrouter.ai/docs/parameters
+"""
+
+from typing import Optional
+
+from litellm import get_model_info, verbose_logger
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class OpenrouterConfig(OpenAIGPTConfig):
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        mapped_openai_params = super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
+
+        # OpenRouter-only parameters
+        extra_body = {}
+        transforms = non_default_params.pop("transforms", None)
+        models = non_default_params.pop("models", None)
+        route = non_default_params.pop("route", None)
+        if transforms is not None:
+            extra_body["transforms"] = transforms
+        if models is not None:
+            extra_body["models"] = models
+        if route is not None:
+            extra_body["route"] = route
+        mapped_openai_params["extra_body"] = (
+            extra_body  # openai client supports `extra_body` param
+        )
+        return mapped_openai_params
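Standalone sketch (not part of this diff) of how the OpenRouter-only parameters are folded into extra_body above, so the OpenAI client forwards them in the request body; the helper name is illustrative.

def add_openrouter_extra_body(non_default_params: dict, mapped_params: dict) -> dict:
    # collect OpenRouter-specific keys into the extra_body the openai client accepts
    extra_body = {}
    for key in ("transforms", "models", "route"):
        value = non_default_params.pop(key, None)
        if value is not None:
            extra_body[key] = value
    mapped_params["extra_body"] = extra_body
    return mapped_params

params = add_openrouter_extra_body(
    {"models": ["anthropic/claude-3.5-sonnet", "openai/gpt-4o"], "temperature": 0.7},
    {"temperature": 0.7},
)
print(params)  # {'temperature': 0.7, 'extra_body': {'models': [...]}}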
@@ -4,7 +4,7 @@ Common utility functions used for translating messages across providers

 import json
 from copy import deepcopy
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union, cast

 import litellm
 from litellm.types.llms.openai import (
@@ -53,6 +53,13 @@ def strip_name_from_messages(
     return new_messages


+def strip_none_values_from_message(message: AllMessageValues) -> AllMessageValues:
+    """
+    Strips None values from message
+    """
+    return cast(AllMessageValues, {k: v for k, v in message.items() if v is not None})
+
+
 def convert_content_list_to_str(message: AllMessageValues) -> str:
     """
     - handles scenario where content is list and not string
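Quick usage sketch for the helper added above (standalone re-implementation for illustration only; the real function returns a typed AllMessageValues):

def strip_none_values(message: dict) -> dict:
    # drop keys whose value is None so they are not sent to the provider
    return {k: v for k, v in message.items() if v is not None}

print(strip_none_values({"role": "assistant", "content": "hi", "tool_calls": None}))
# -> {'role': 'assistant', 'content': 'hi'}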
@@ -2856,7 +2856,7 @@ def prompt_factory(
         else:
             return gemini_text_image_pt(messages=messages)
     elif custom_llm_provider == "mistral":
-        return litellm.MistralConfig._transform_messages(messages=messages)
+        return litellm.MistralConfig()._transform_messages(messages=messages)
     elif custom_llm_provider == "bedrock":
         if "amazon.titan-text" in model:
             return amazon_titan_pt(messages=messages)
@@ -1,609 +0,0 @@
-import asyncio
-import json
-import os
-import time
-import types
-from typing import Any, Callable, Optional, Tuple, Union
-
-import httpx  # type: ignore
-import requests  # type: ignore
-
-import litellm
-from litellm.llms.custom_httpx.http_handler import (
-    AsyncHTTPHandler,
-    get_async_httpx_client,
-)
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
-
-from .prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class ReplicateError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url="https://api.replicate.com/v1/deployments"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class ReplicateConfig:
-    """
-    Reference: https://replicate.com/meta/llama-2-70b-chat/api
-    - `prompt` (string): The prompt to send to the model.
-
-    - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
-
-    - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
-
-    - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
-
-    - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
-
-    - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
-
-    - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
-
-    - `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either 'end' or '<stop>'.
-
-    - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
-
-    - `debug` (boolean): If set to `True`, it provides debugging output in logs.
-
-    Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
-    """
-
-    system_prompt: Optional[str] = None
-    max_new_tokens: Optional[int] = None
-    min_new_tokens: Optional[int] = None
-    temperature: Optional[int] = None
-    top_p: Optional[int] = None
-    top_k: Optional[int] = None
-    stop_sequences: Optional[str] = None
-    seed: Optional[int] = None
-    debug: Optional[bool] = None
-
-    def __init__(
-        self,
-        system_prompt: Optional[str] = None,
-        max_new_tokens: Optional[int] = None,
-        min_new_tokens: Optional[int] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
-        top_k: Optional[int] = None,
-        stop_sequences: Optional[str] = None,
-        seed: Optional[int] = None,
-        debug: Optional[bool] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-
-# Function to start a prediction and get the prediction URL
-def start_prediction(
-    version_id, input_data, api_token, api_base, logging_obj, print_verbose
-):
-    base_url = api_base
-    if "deployments" in version_id:
-        print_verbose("\nLiteLLM: Request to custom replicate deployment")
-        version_id = version_id.replace("deployments/", "")
-        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
-        print_verbose(f"Deployment base URL: {base_url}\n")
-    else:  # assume it's a model
-        base_url = f"https://api.replicate.com/v1/models/{version_id}"
-    headers = {
-        "Authorization": f"Token {api_token}",
-        "Content-Type": "application/json",
-    }
-
-    initial_prediction_data = {
-        "input": input_data,
-    }
-
-    if ":" in version_id and len(version_id) > 64:
-        model_parts = version_id.split(":")
-        if (
-            len(model_parts) > 1 and len(model_parts[1]) == 64
-        ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
-            initial_prediction_data["version"] = model_parts[1]
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=input_data["prompt"],
-        api_key="",
-        additional_args={
-            "complete_input_dict": initial_prediction_data,
-            "headers": headers,
-            "api_base": base_url,
-        },
-    )
-
-    response = requests.post(
-        f"{base_url}/predictions", json=initial_prediction_data, headers=headers
-    )
-    if response.status_code == 201:
-        response_data = response.json()
-        return response_data.get("urls", {}).get("get")
-    else:
-        raise ReplicateError(
-            response.status_code, f"Failed to start prediction {response.text}"
-        )
-
-
-async def async_start_prediction(
-    version_id,
-    input_data,
-    api_token,
-    api_base,
-    logging_obj,
-    print_verbose,
-    http_handler: AsyncHTTPHandler,
-) -> str:
-    base_url = api_base
-    if "deployments" in version_id:
-        print_verbose("\nLiteLLM: Request to custom replicate deployment")
-        version_id = version_id.replace("deployments/", "")
-        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
-        print_verbose(f"Deployment base URL: {base_url}\n")
-    else:  # assume it's a model
-        base_url = f"https://api.replicate.com/v1/models/{version_id}"
-    headers = {
-        "Authorization": f"Token {api_token}",
-        "Content-Type": "application/json",
-    }
-
-    initial_prediction_data = {
-        "input": input_data,
-    }
-
-    if ":" in version_id and len(version_id) > 64:
-        model_parts = version_id.split(":")
-        if (
-            len(model_parts) > 1 and len(model_parts[1]) == 64
-        ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
-            initial_prediction_data["version"] = model_parts[1]
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=input_data["prompt"],
-        api_key="",
-        additional_args={
-            "complete_input_dict": initial_prediction_data,
-            "headers": headers,
-            "api_base": base_url,
-        },
-    )
-
-    response = await http_handler.post(
-        url="{}/predictions".format(base_url),
-        data=json.dumps(initial_prediction_data),
-        headers=headers,
-    )
-
-    if response.status_code == 201:
-        response_data = response.json()
-        return response_data.get("urls", {}).get("get")
-    else:
-        raise ReplicateError(
-            response.status_code, f"Failed to start prediction {response.text}"
-        )
-
-
-# Function to handle prediction response (non-streaming)
-def handle_prediction_response(prediction_url, api_token, print_verbose):
-    output_string = ""
-    headers = {
-        "Authorization": f"Token {api_token}",
-        "Content-Type": "application/json",
-    }
-
-    status = ""
-    logs = ""
-    while True and (status not in ["succeeded", "failed", "canceled"]):
-        print_verbose(f"replicate: polling endpoint: {prediction_url}")
-        time.sleep(0.5)
-        response = requests.get(prediction_url, headers=headers)
-        if response.status_code == 200:
-            response_data = response.json()
-            if "output" in response_data:
-                output_string = "".join(response_data["output"])
-                print_verbose(f"Non-streamed output:{output_string}")
-            status = response_data.get("status", None)
-            logs = response_data.get("logs", "")
-            if status == "failed":
-                replicate_error = response_data.get("error", "")
-                raise ReplicateError(
-                    status_code=400,
-                    message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
-                )
-        else:
-            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
-            print_verbose("Replicate: Failed to fetch prediction status and output.")
-    return output_string, logs
-
-
-async def async_handle_prediction_response(
-    prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
-) -> Tuple[str, Any]:
-    output_string = ""
-    headers = {
-        "Authorization": f"Token {api_token}",
-        "Content-Type": "application/json",
-    }
-
-    status = ""
-    logs = ""
-    while True and (status not in ["succeeded", "failed", "canceled"]):
-        print_verbose(f"replicate: polling endpoint: {prediction_url}")
-        await asyncio.sleep(0.5)  # prevent replicate rate limit errors
-        response = await http_handler.get(prediction_url, headers=headers)
-        if response.status_code == 200:
-            response_data = response.json()
-            if "output" in response_data:
-                output_string = "".join(response_data["output"])
|
|
||||||
print_verbose(f"Non-streamed output:{output_string}")
|
|
||||||
status = response_data.get("status", None)
|
|
||||||
logs = response_data.get("logs", "")
|
|
||||||
if status == "failed":
|
|
||||||
replicate_error = response_data.get("error", "")
|
|
||||||
raise ReplicateError(
|
|
||||||
status_code=400,
|
|
||||||
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
|
||||||
print_verbose("Replicate: Failed to fetch prediction status and output.")
|
|
||||||
return output_string, logs
|
|
||||||
|
|
||||||
|
|
||||||
# Function to handle prediction response (streaming)
|
|
||||||
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
|
|
||||||
previous_output = ""
|
|
||||||
output_string = ""
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Token {api_token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
status = ""
|
|
||||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
|
||||||
time.sleep(0.5) # prevent being rate limited by replicate
|
|
||||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
|
||||||
response = requests.get(prediction_url, headers=headers)
|
|
||||||
if response.status_code == 200:
|
|
||||||
response_data = response.json()
|
|
||||||
status = response_data["status"]
|
|
||||||
if "output" in response_data:
|
|
||||||
try:
|
|
||||||
output_string = "".join(response_data["output"])
|
|
||||||
except Exception:
|
|
||||||
raise ReplicateError(
|
|
||||||
status_code=422,
|
|
||||||
message="Unable to parse response. Got={}".format(
|
|
||||||
response_data["output"]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
new_output = output_string[len(previous_output) :]
|
|
||||||
print_verbose(f"New chunk: {new_output}")
|
|
||||||
yield {"output": new_output, "status": status}
|
|
||||||
previous_output = output_string
|
|
||||||
status = response_data["status"]
|
|
||||||
if status == "failed":
|
|
||||||
replicate_error = response_data.get("error", "")
|
|
||||||
raise ReplicateError(
|
|
||||||
status_code=400, message=f"Error: {replicate_error}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
|
||||||
print_verbose(
|
|
||||||
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Function to handle prediction response (streaming)
|
|
||||||
async def async_handle_prediction_response_streaming(
|
|
||||||
prediction_url, api_token, print_verbose
|
|
||||||
):
|
|
||||||
http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
|
|
||||||
previous_output = ""
|
|
||||||
output_string = ""
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Token {api_token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
status = ""
|
|
||||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
|
||||||
await asyncio.sleep(0.5) # prevent being rate limited by replicate
|
|
||||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
|
||||||
response = await http_handler.get(prediction_url, headers=headers)
|
|
||||||
if response.status_code == 200:
|
|
||||||
response_data = response.json()
|
|
||||||
status = response_data["status"]
|
|
||||||
if "output" in response_data:
|
|
||||||
try:
|
|
||||||
output_string = "".join(response_data["output"])
|
|
||||||
except Exception:
|
|
||||||
raise ReplicateError(
|
|
||||||
status_code=422,
|
|
||||||
message="Unable to parse response. Got={}".format(
|
|
||||||
response_data["output"]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
new_output = output_string[len(previous_output) :]
|
|
||||||
print_verbose(f"New chunk: {new_output}")
|
|
||||||
yield {"output": new_output, "status": status}
|
|
||||||
previous_output = output_string
|
|
||||||
status = response_data["status"]
|
|
||||||
if status == "failed":
|
|
||||||
replicate_error = response_data.get("error", "")
|
|
||||||
raise ReplicateError(
|
|
||||||
status_code=400, message=f"Error: {replicate_error}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
|
||||||
print_verbose(
|
|
||||||
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Function to extract version ID from model string
|
|
||||||
def model_to_version_id(model):
|
|
||||||
if ":" in model:
|
|
||||||
split_model = model.split(":")
|
|
||||||
return split_model[1]
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def process_response(
|
|
||||||
model_response: ModelResponse,
|
|
||||||
result: str,
|
|
||||||
model: str,
|
|
||||||
encoding: Any,
|
|
||||||
prompt: str,
|
|
||||||
) -> ModelResponse:
|
|
||||||
if len(result) == 0: # edge case, where result from replicate is empty
|
|
||||||
result = " "
|
|
||||||
|
|
||||||
## Building RESPONSE OBJECT
|
|
||||||
if len(result) >= 1:
|
|
||||||
model_response.choices[0].message.content = result # type: ignore
|
|
||||||
|
|
||||||
# Calculate usage
|
|
||||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
|
||||||
completion_tokens = len(
|
|
||||||
encoding.encode(
|
|
||||||
model_response["choices"][0]["message"].get("content", ""),
|
|
||||||
disallowed_special=(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
model_response.model = "replicate/" + model
|
|
||||||
usage = Usage(
|
|
||||||
prompt_tokens=prompt_tokens,
|
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
|
||||||
)
|
|
||||||
setattr(model_response, "usage", usage)
|
|
||||||
|
|
||||||
return model_response
|
|
||||||
|
|
||||||
|
|
||||||
# Main function for prediction completion
|
|
||||||
def completion(
|
|
||||||
model: str,
|
|
||||||
messages: list,
|
|
||||||
api_base: str,
|
|
||||||
model_response: ModelResponse,
|
|
||||||
print_verbose: Callable,
|
|
||||||
optional_params: dict,
|
|
||||||
logging_obj,
|
|
||||||
api_key,
|
|
||||||
encoding,
|
|
||||||
custom_prompt_dict={},
|
|
||||||
litellm_params=None,
|
|
||||||
logger_fn=None,
|
|
||||||
acompletion=None,
|
|
||||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
|
||||||
# Start a prediction and get the prediction URL
|
|
||||||
version_id = model_to_version_id(model)
|
|
||||||
## Load Config
|
|
||||||
config = litellm.ReplicateConfig.get_config()
|
|
||||||
for k, v in config.items():
|
|
||||||
if (
|
|
||||||
k not in optional_params
|
|
||||||
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
|
|
||||||
optional_params[k] = v
|
|
||||||
|
|
||||||
system_prompt = None
|
|
||||||
if optional_params is not None and "supports_system_prompt" in optional_params:
|
|
||||||
supports_sys_prompt = optional_params.pop("supports_system_prompt")
|
|
||||||
else:
|
|
||||||
supports_sys_prompt = False
|
|
||||||
|
|
||||||
if supports_sys_prompt:
|
|
||||||
for i in range(len(messages)):
|
|
||||||
if messages[i]["role"] == "system":
|
|
||||||
first_sys_message = messages.pop(i)
|
|
||||||
system_prompt = first_sys_message["content"]
|
|
||||||
break
|
|
||||||
|
|
||||||
if model in custom_prompt_dict:
|
|
||||||
# check if the model has a registered custom prompt
|
|
||||||
model_prompt_details = custom_prompt_dict[model]
|
|
||||||
prompt = custom_prompt(
|
|
||||||
role_dict=model_prompt_details.get("roles", {}),
|
|
||||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
|
||||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
|
||||||
bos_token=model_prompt_details.get("bos_token", ""),
|
|
||||||
eos_token=model_prompt_details.get("eos_token", ""),
|
|
||||||
messages=messages,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = prompt_factory(model=model, messages=messages)
|
|
||||||
|
|
||||||
if prompt is None or not isinstance(prompt, str):
|
|
||||||
raise ReplicateError(
|
|
||||||
status_code=400,
|
|
||||||
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
|
|
||||||
)
|
|
||||||
|
|
||||||
# If system prompt is supported, and a system prompt is provided, use it
|
|
||||||
if system_prompt is not None:
|
|
||||||
input_data = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"system_prompt": system_prompt,
|
|
||||||
**optional_params,
|
|
||||||
}
|
|
||||||
# Otherwise, use the prompt as is
|
|
||||||
else:
|
|
||||||
input_data = {"prompt": prompt, **optional_params}
|
|
||||||
|
|
||||||
if acompletion is not None and acompletion is True:
|
|
||||||
return async_completion(
|
|
||||||
model_response=model_response,
|
|
||||||
model=model,
|
|
||||||
prompt=prompt,
|
|
||||||
encoding=encoding,
|
|
||||||
optional_params=optional_params,
|
|
||||||
version_id=version_id,
|
|
||||||
input_data=input_data,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
print_verbose=print_verbose,
|
|
||||||
) # type: ignore
|
|
||||||
## COMPLETION CALL
|
|
||||||
## Replicate Compeltion calls have 2 steps
|
|
||||||
## Step1: Start Prediction: gets a prediction url
|
|
||||||
## Step2: Poll prediction url for response
|
|
||||||
## Step2: is handled with and without streaming
|
|
||||||
model_response.created = int(
|
|
||||||
time.time()
|
|
||||||
) # for pricing this must remain right before calling api
|
|
||||||
|
|
||||||
prediction_url = start_prediction(
|
|
||||||
version_id,
|
|
||||||
input_data,
|
|
||||||
api_key,
|
|
||||||
api_base,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
print_verbose=print_verbose,
|
|
||||||
)
|
|
||||||
print_verbose(prediction_url)
|
|
||||||
|
|
||||||
# Handle the prediction response (streaming or non-streaming)
|
|
||||||
if "stream" in optional_params and optional_params["stream"] is True:
|
|
||||||
print_verbose("streaming request")
|
|
||||||
_response = handle_prediction_response_streaming(
|
|
||||||
prediction_url, api_key, print_verbose
|
|
||||||
)
|
|
||||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
|
||||||
else:
|
|
||||||
result, logs = handle_prediction_response(
|
|
||||||
prediction_url, api_key, print_verbose
|
|
||||||
)
|
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=prompt,
|
|
||||||
api_key="",
|
|
||||||
original_response=result,
|
|
||||||
additional_args={
|
|
||||||
"complete_input_dict": input_data,
|
|
||||||
"logs": logs,
|
|
||||||
"api_base": prediction_url,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
print_verbose(f"raw model_response: {result}")
|
|
||||||
|
|
||||||
return process_response(
|
|
||||||
model_response=model_response,
|
|
||||||
result=result,
|
|
||||||
model=model,
|
|
||||||
encoding=encoding,
|
|
||||||
prompt=prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_completion(
|
|
||||||
model_response: ModelResponse,
|
|
||||||
model: str,
|
|
||||||
prompt: str,
|
|
||||||
encoding,
|
|
||||||
optional_params: dict,
|
|
||||||
version_id,
|
|
||||||
input_data,
|
|
||||||
api_key,
|
|
||||||
api_base,
|
|
||||||
logging_obj,
|
|
||||||
print_verbose,
|
|
||||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
|
||||||
http_handler = get_async_httpx_client(
|
|
||||||
llm_provider=litellm.LlmProviders.REPLICATE,
|
|
||||||
)
|
|
||||||
prediction_url = await async_start_prediction(
|
|
||||||
version_id,
|
|
||||||
input_data,
|
|
||||||
api_key,
|
|
||||||
api_base,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
print_verbose=print_verbose,
|
|
||||||
http_handler=http_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] is True:
|
|
||||||
_response = async_handle_prediction_response_streaming(
|
|
||||||
prediction_url, api_key, print_verbose
|
|
||||||
)
|
|
||||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
|
||||||
|
|
||||||
result, logs = await async_handle_prediction_response(
|
|
||||||
prediction_url, api_key, print_verbose, http_handler=http_handler
|
|
||||||
)
|
|
||||||
|
|
||||||
return process_response(
|
|
||||||
model_response=model_response,
|
|
||||||
result=result,
|
|
||||||
model=model,
|
|
||||||
encoding=encoding,
|
|
||||||
prompt=prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# # Example usage:
|
|
||||||
# response = completion(
|
|
||||||
# api_key="",
|
|
||||||
# messages=[{"content": "good morning"}],
|
|
||||||
# model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
|
|
||||||
# model_response=ModelResponse(),
|
|
||||||
# print_verbose=print,
|
|
||||||
# logging_obj=print, # stub logging_obj
|
|
||||||
# optional_params={"stream": False}
|
|
||||||
# )
|
|
||||||
|
|
||||||
# print(response)
|
|
285
litellm/llms/replicate/chat/handler.py
Normal file
285
litellm/llms/replicate/chat/handler.py
Normal file
|
@ -0,0 +1,285 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
HTTPHandler,
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||||
|
|
||||||
|
from ...prompt_templates.factory import custom_prompt, prompt_factory
|
||||||
|
from ..common_utils import ReplicateError
|
||||||
|
from .transformation import ReplicateConfig
|
||||||
|
|
||||||
|
replicate_config = ReplicateConfig()
|
||||||
|
|
||||||
|
|
||||||
|
# Function to handle prediction response (streaming)
|
||||||
|
def handle_prediction_response_streaming(
|
||||||
|
prediction_url, api_token, print_verbose, headers: dict, http_client: HTTPHandler
|
||||||
|
):
|
||||||
|
previous_output = ""
|
||||||
|
output_string = ""
|
||||||
|
|
||||||
|
status = ""
|
||||||
|
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||||
|
time.sleep(0.5) # prevent being rate limited by replicate
|
||||||
|
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||||
|
response = http_client.get(prediction_url, headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
response_data = response.json()
|
||||||
|
status = response_data["status"]
|
||||||
|
if "output" in response_data:
|
||||||
|
try:
|
||||||
|
output_string = "".join(response_data["output"])
|
||||||
|
except Exception:
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=422,
|
||||||
|
message="Unable to parse response. Got={}".format(
|
||||||
|
response_data["output"]
|
||||||
|
),
|
||||||
|
headers=response.headers,
|
||||||
|
)
|
||||||
|
new_output = output_string[len(previous_output) :]
|
||||||
|
print_verbose(f"New chunk: {new_output}")
|
||||||
|
yield {"output": new_output, "status": status}
|
||||||
|
previous_output = output_string
|
||||||
|
status = response_data["status"]
|
||||||
|
if status == "failed":
|
||||||
|
replicate_error = response_data.get("error", "")
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=400,
|
||||||
|
message=f"Error: {replicate_error}",
|
||||||
|
headers=response.headers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||||
|
print_verbose(
|
||||||
|
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Function to handle prediction response (streaming)
|
||||||
|
async def async_handle_prediction_response_streaming(
|
||||||
|
prediction_url,
|
||||||
|
api_token,
|
||||||
|
print_verbose,
|
||||||
|
headers: dict,
|
||||||
|
http_client: AsyncHTTPHandler,
|
||||||
|
):
|
||||||
|
previous_output = ""
|
||||||
|
output_string = ""
|
||||||
|
|
||||||
|
status = ""
|
||||||
|
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||||
|
await asyncio.sleep(0.5) # prevent being rate limited by replicate
|
||||||
|
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||||
|
response = await http_client.get(prediction_url, headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
response_data = response.json()
|
||||||
|
status = response_data["status"]
|
||||||
|
if "output" in response_data:
|
||||||
|
try:
|
||||||
|
output_string = "".join(response_data["output"])
|
||||||
|
except Exception:
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=422,
|
||||||
|
message="Unable to parse response. Got={}".format(
|
||||||
|
response_data["output"]
|
||||||
|
),
|
||||||
|
headers=response.headers,
|
||||||
|
)
|
||||||
|
new_output = output_string[len(previous_output) :]
|
||||||
|
print_verbose(f"New chunk: {new_output}")
|
||||||
|
yield {"output": new_output, "status": status}
|
||||||
|
previous_output = output_string
|
||||||
|
status = response_data["status"]
|
||||||
|
if status == "failed":
|
||||||
|
replicate_error = response_data.get("error", "")
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=400,
|
||||||
|
message=f"Error: {replicate_error}",
|
||||||
|
headers=response.headers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||||
|
print_verbose(
|
||||||
|
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Main function for prediction completion
|
||||||
|
def completion(
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
logging_obj,
|
||||||
|
api_key,
|
||||||
|
encoding,
|
||||||
|
custom_prompt_dict={},
|
||||||
|
logger_fn=None,
|
||||||
|
acompletion=None,
|
||||||
|
headers={},
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
headers = replicate_config.validate_environment(
|
||||||
|
api_key=api_key,
|
||||||
|
headers=headers,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
|
# Start a prediction and get the prediction URL
|
||||||
|
version_id = replicate_config.model_to_version_id(model)
|
||||||
|
input_data = replicate_config.transform_request(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if acompletion is not None and acompletion is True:
|
||||||
|
return async_completion(
|
||||||
|
model_response=model_response,
|
||||||
|
model=model,
|
||||||
|
encoding=encoding,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
version_id=version_id,
|
||||||
|
input_data=input_data,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
headers=headers,
|
||||||
|
) # type: ignore
|
||||||
|
## COMPLETION CALL
|
||||||
|
model_response.created = int(
|
||||||
|
time.time()
|
||||||
|
) # for pricing this must remain right before calling api
|
||||||
|
|
||||||
|
prediction_url = replicate_config.get_complete_url(api_base, model)
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
httpx_client = _get_httpx_client(
|
||||||
|
params={"timeout": 600.0},
|
||||||
|
)
|
||||||
|
response = httpx_client.post(
|
||||||
|
url=prediction_url,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(input_data),
|
||||||
|
)
|
||||||
|
|
||||||
|
prediction_url = replicate_config.get_prediction_url(response)
|
||||||
|
|
||||||
|
# Handle the prediction response (streaming or non-streaming)
|
||||||
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
|
print_verbose("streaming request")
|
||||||
|
_response = handle_prediction_response_streaming(
|
||||||
|
prediction_url,
|
||||||
|
api_key,
|
||||||
|
print_verbose,
|
||||||
|
headers=headers,
|
||||||
|
http_client=httpx_client,
|
||||||
|
)
|
||||||
|
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||||
|
else:
|
||||||
|
for _ in range(litellm.DEFAULT_MAX_RETRIES):
|
||||||
|
time.sleep(
|
||||||
|
1
|
||||||
|
) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing"
|
||||||
|
response = httpx_client.get(url=prediction_url, headers=headers)
|
||||||
|
return litellm.ReplicateConfig().transform_response(
|
||||||
|
model=model,
|
||||||
|
raw_response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
request_data=input_data,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=500,
|
||||||
|
message="No response received from Replicate API after max retries",
|
||||||
|
headers=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
model_response: ModelResponse,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
encoding,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
version_id,
|
||||||
|
input_data,
|
||||||
|
api_key,
|
||||||
|
api_base,
|
||||||
|
logging_obj,
|
||||||
|
print_verbose,
|
||||||
|
headers: dict,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
|
||||||
|
prediction_url = replicate_config.get_complete_url(api_base=api_base, model=model)
|
||||||
|
async_handler = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.REPLICATE,
|
||||||
|
params={"timeout": 600.0},
|
||||||
|
)
|
||||||
|
response = await async_handler.post(
|
||||||
|
url=prediction_url, headers=headers, data=json.dumps(input_data)
|
||||||
|
)
|
||||||
|
prediction_url = replicate_config.get_prediction_url(response)
|
||||||
|
|
||||||
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
|
_response = async_handle_prediction_response_streaming(
|
||||||
|
prediction_url,
|
||||||
|
api_key,
|
||||||
|
print_verbose,
|
||||||
|
headers=headers,
|
||||||
|
http_client=async_handler,
|
||||||
|
)
|
||||||
|
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||||
|
|
||||||
|
for _ in range(litellm.DEFAULT_MAX_RETRIES):
|
||||||
|
await asyncio.sleep(
|
||||||
|
1
|
||||||
|
) # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing"
|
||||||
|
response = await async_handler.get(url=prediction_url, headers=headers)
|
||||||
|
return litellm.ReplicateConfig().transform_response(
|
||||||
|
model=model,
|
||||||
|
raw_response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
request_data=input_data,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
# Add a fallback return if no response is received after max retries
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=500,
|
||||||
|
message="No response received from Replicate API after max retries",
|
||||||
|
headers=None,
|
||||||
|
)
|
312
litellm/llms/replicate/chat/transformation.py
Normal file
312
litellm/llms/replicate/chat/transformation.py
Normal file
|
@ -0,0 +1,312 @@
|
||||||
|
import types
|
||||||
|
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||||
|
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
|
||||||
|
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||||
|
from litellm.utils import token_counter
|
||||||
|
|
||||||
|
from ..common_utils import ReplicateError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
|
||||||
|
LoggingClass = LiteLLMLoggingObj
|
||||||
|
else:
|
||||||
|
LoggingClass = Any
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicateConfig(BaseConfig):
|
||||||
|
"""
|
||||||
|
Reference: https://replicate.com/meta/llama-2-70b-chat/api
|
||||||
|
- `prompt` (string): The prompt to send to the model.
|
||||||
|
|
||||||
|
- `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
|
||||||
|
|
||||||
|
- `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
|
||||||
|
|
||||||
|
- `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
|
||||||
|
|
||||||
|
- `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
|
||||||
|
|
||||||
|
- `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
|
||||||
|
|
||||||
|
- `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
|
||||||
|
|
||||||
|
- `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either 'end' or '<stop>'.
|
||||||
|
|
||||||
|
- `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
|
||||||
|
|
||||||
|
- `debug` (boolean): If set to `True`, it provides debugging output in logs.
|
||||||
|
|
||||||
|
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt: Optional[str] = None
|
||||||
|
max_new_tokens: Optional[int] = None
|
||||||
|
min_new_tokens: Optional[int] = None
|
||||||
|
temperature: Optional[int] = None
|
||||||
|
top_p: Optional[int] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
stop_sequences: Optional[str] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
|
debug: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_new_tokens: Optional[int] = None,
|
||||||
|
min_new_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[str] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
debug: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return super().get_config()
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> list:
|
||||||
|
return [
|
||||||
|
"stream",
|
||||||
|
"temperature",
|
||||||
|
"max_tokens",
|
||||||
|
"top_p",
|
||||||
|
"stop",
|
||||||
|
"seed",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"functions",
|
||||||
|
"function_call",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "max_tokens":
|
||||||
|
if "vicuna" in model or "flan" in model:
|
||||||
|
optional_params["max_length"] = value
|
||||||
|
elif "meta/codellama-13b" in model:
|
||||||
|
optional_params["max_tokens"] = value
|
||||||
|
else:
|
||||||
|
optional_params["max_new_tokens"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if param == "stop":
|
||||||
|
optional_params["stop_sequences"] = value
|
||||||
|
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
# Function to extract version ID from model string
|
||||||
|
def model_to_version_id(self, model: str) -> str:
|
||||||
|
if ":" in model:
|
||||||
|
split_model = model.split(":")
|
||||||
|
return split_model[1]
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _transform_messages(
|
||||||
|
self, messages: List[AllMessageValues]
|
||||||
|
) -> List[AllMessageValues]:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def get_error_class(
|
||||||
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
|
) -> BaseLLMException:
|
||||||
|
return ReplicateError(
|
||||||
|
status_code=status_code, message=error_message, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_complete_url(self, api_base: str, model: str) -> str:
|
||||||
|
version_id = self.model_to_version_id(model)
|
||||||
|
base_url = api_base
|
||||||
|
if "deployments" in version_id:
|
||||||
|
version_id = version_id.replace("deployments/", "")
|
||||||
|
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
|
||||||
|
else: # assume it's a model
|
||||||
|
base_url = f"https://api.replicate.com/v1/models/{version_id}"
|
||||||
|
|
||||||
|
base_url = f"{base_url}/predictions"
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
def transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
## Load Config
|
||||||
|
config = litellm.ReplicateConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in optional_params
|
||||||
|
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
optional_params[k] = v
|
||||||
|
|
||||||
|
system_prompt = None
|
||||||
|
if optional_params is not None and "supports_system_prompt" in optional_params:
|
||||||
|
supports_sys_prompt = optional_params.pop("supports_system_prompt")
|
||||||
|
else:
|
||||||
|
supports_sys_prompt = False
|
||||||
|
|
||||||
|
if supports_sys_prompt:
|
||||||
|
for i in range(len(messages)):
|
||||||
|
if messages[i]["role"] == "system":
|
||||||
|
first_sys_message = messages.pop(i)
|
||||||
|
system_prompt = convert_content_list_to_str(first_sys_message)
|
||||||
|
break
|
||||||
|
|
||||||
|
if model in litellm.custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = litellm.custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details.get("roles", {}),
|
||||||
|
initial_prompt_value=model_prompt_details.get(
|
||||||
|
"initial_prompt_value", ""
|
||||||
|
),
|
||||||
|
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||||
|
bos_token=model_prompt_details.get("bos_token", ""),
|
||||||
|
eos_token=model_prompt_details.get("eos_token", ""),
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(model=model, messages=messages)
|
||||||
|
|
||||||
|
if prompt is None or not isinstance(prompt, str):
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=400,
|
||||||
|
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
|
||||||
|
headers={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# If system prompt is supported, and a system prompt is provided, use it
|
||||||
|
if system_prompt is not None:
|
||||||
|
input_data = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
**optional_params,
|
||||||
|
}
|
||||||
|
# Otherwise, use the prompt as is
|
||||||
|
else:
|
||||||
|
input_data = {"prompt": prompt, **optional_params}
|
||||||
|
|
||||||
|
version_id = self.model_to_version_id(model)
|
||||||
|
request_data: dict = {"input": input_data}
|
||||||
|
if ":" in version_id and len(version_id) > 64:
|
||||||
|
model_parts = version_id.split(":")
|
||||||
|
if (
|
||||||
|
len(model_parts) > 1 and len(model_parts[1]) == 64
|
||||||
|
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
||||||
|
request_data["version"] = model_parts[1]
|
||||||
|
|
||||||
|
return request_data
|
||||||
|
|
||||||
|
def transform_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
raw_response: httpx.Response,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
logging_obj: LoggingClass,
|
||||||
|
request_data: dict,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
encoding: Any,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
json_mode: Optional[bool] = None,
|
||||||
|
) -> ModelResponse:
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=raw_response.text,
|
||||||
|
additional_args={"complete_input_dict": request_data},
|
||||||
|
)
|
||||||
|
raw_response_json = raw_response.json()
|
||||||
|
if raw_response_json.get("status") != "succeeded":
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=422,
|
||||||
|
message="LiteLLM Error - prediction not succeeded - {}".format(
|
||||||
|
raw_response_json
|
||||||
|
),
|
||||||
|
headers=raw_response.headers,
|
||||||
|
)
|
||||||
|
outputs = raw_response_json.get("output", [])
|
||||||
|
response_str = "".join(outputs)
|
||||||
|
if len(response_str) == 0: # edge case, where result from replicate is empty
|
||||||
|
response_str = " "
|
||||||
|
|
||||||
|
## Building RESPONSE OBJECT
|
||||||
|
if len(response_str) >= 1:
|
||||||
|
model_response.choices[0].message.content = response_str # type: ignore
|
||||||
|
|
||||||
|
# Calculate usage
|
||||||
|
prompt_tokens = token_counter(model=model, messages=messages)
|
||||||
|
completion_tokens = token_counter(
|
||||||
|
model=model,
|
||||||
|
text=response_str,
|
||||||
|
count_response_tokens=True,
|
||||||
|
)
|
||||||
|
model_response.model = "replicate/" + model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def get_prediction_url(self, response: httpx.Response) -> str:
|
||||||
|
"""
|
||||||
|
response json: {
|
||||||
|
...,
|
||||||
|
"urls":{"cancel":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4/cancel","get":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4","stream":"https://stream-b.svc.rno2.c.replicate.net/v1/streams/eot4gbydowuin4snhncydwxt57dfwgsc3w3snycx5nid7oef7jga"}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
response_json = response.json()
|
||||||
|
prediction_url = response_json.get("urls", {}).get("get")
|
||||||
|
if prediction_url is None:
|
||||||
|
raise ReplicateError(
|
||||||
|
status_code=400,
|
||||||
|
message="LiteLLM Error - prediction url is None - {}".format(
|
||||||
|
response_json
|
||||||
|
),
|
||||||
|
headers=response.headers,
|
||||||
|
)
|
||||||
|
return prediction_url
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: dict,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Token {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
return headers
|
15
litellm/llms/replicate/common_utils.py
Normal file
15
litellm/llms/replicate/common_utils.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from litellm.llms.base_llm.transformation import BaseLLMException
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicateError(BaseLLMException):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
status_code: int,
|
||||||
|
message: str,
|
||||||
|
headers: Optional[Union[dict, httpx.Headers]],
|
||||||
|
):
|
||||||
|
super().__init__(status_code=status_code, message=message, headers=headers)
|
|
@ -363,6 +363,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def make_async_call(
|
async def make_async_call(
|
||||||
|
@ -562,6 +563,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def embedding(
|
def embedding(
|
||||||
|
|
|
@ -202,6 +202,7 @@ class SagemakerConfig(BaseConfig):
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
encoding: str,
|
encoding: str,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
json_mode: Optional[bool] = None,
|
json_mode: Optional[bool] = None,
|
||||||
|
|
|
@ -7,8 +7,10 @@ this is OpenAI compatible - no translation needed / occurs
|
||||||
import types
|
import types
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
|
|
||||||
class SambanovaConfig:
|
|
||||||
|
class SambanovaConfig(OpenAIGPTConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://community.sambanova.ai/t/create-chat-completion-api/
|
Reference: https://community.sambanova.ai/t/create-chat-completion-api/
|
||||||
|
|
||||||
|
@ -18,9 +20,7 @@ class SambanovaConfig:
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
response_format: Optional[dict] = None
|
response_format: Optional[dict] = None
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
stop: Optional[str] = None
|
|
||||||
stream: Optional[bool] = None
|
stream: Optional[bool] = None
|
||||||
temperature: Optional[float] = None
|
|
||||||
top_p: Optional[int] = None
|
top_p: Optional[int] = None
|
||||||
tool_choice: Optional[str] = None
|
tool_choice: Optional[str] = None
|
||||||
tools: Optional[list] = None
|
tools: Optional[list] = None
|
||||||
|
@ -46,21 +46,7 @@ class SambanovaConfig:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls):
|
def get_config(cls):
|
||||||
return {
|
return super().get_config()
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if not k.startswith("__")
|
|
||||||
and not isinstance(
|
|
||||||
v,
|
|
||||||
(
|
|
||||||
types.FunctionType,
|
|
||||||
types.BuiltinFunctionType,
|
|
||||||
classmethod,
|
|
||||||
staticmethod,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
and v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_supported_openai_params(self, model: str) -> list:
|
def get_supported_openai_params(self, model: str) -> list:
|
||||||
"""
|
"""
|
||||||
|
@ -80,12 +66,3 @@ class SambanovaConfig:
|
||||||
"tools",
|
"tools",
|
||||||
"user",
|
"user",
|
||||||
]
|
]
|
||||||
|
|
||||||
def map_openai_params(
|
|
||||||
self, model: str, non_default_params: dict, optional_params: dict
|
|
||||||
) -> dict:
|
|
||||||
supported_openai_params = self.get_supported_openai_params(model=model)
|
|
||||||
for param, value in non_default_params.items():
|
|
||||||
if param in supported_openai_params:
|
|
||||||
optional_params[param] = value
|
|
||||||
return optional_params
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
AsyncHTTPHandler,
|
AsyncHTTPHandler,
|
||||||
get_async_httpx_client,
|
get_async_httpx_client,
|
||||||
)
|
)
|
||||||
|
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
|
||||||
from litellm.types.llms.databricks import GenericStreamingChunk
|
from litellm.types.llms.databricks import GenericStreamingChunk
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
Choices,
|
Choices,
|
||||||
|
@ -91,19 +92,17 @@ async def make_call(
|
||||||
return completion_stream
|
return completion_stream
|
||||||
|
|
||||||
|
|
||||||
class MistralTextCompletionConfig:
|
class MistralTextCompletionConfig(OpenAITextCompletionConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
|
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
suffix: Optional[str] = None
|
suffix: Optional[str] = None
|
||||||
temperature: Optional[int] = None
|
temperature: Optional[int] = None
|
||||||
top_p: Optional[float] = None
|
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
min_tokens: Optional[int] = None
|
min_tokens: Optional[int] = None
|
||||||
stream: Optional[bool] = None
|
stream: Optional[bool] = None
|
||||||
random_seed: Optional[int] = None
|
random_seed: Optional[int] = None
|
||||||
stop: Optional[str] = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -123,23 +122,9 @@ class MistralTextCompletionConfig:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls):
|
def get_config(cls):
|
||||||
return {
|
return super().get_config()
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if not k.startswith("__")
|
|
||||||
and not isinstance(
|
|
||||||
v,
|
|
||||||
(
|
|
||||||
types.FunctionType,
|
|
||||||
types.BuiltinFunctionType,
|
|
||||||
classmethod,
|
|
||||||
staticmethod,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
and v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_supported_openai_params(self):
|
def get_supported_openai_params(self, model: str):
|
||||||
return [
|
return [
|
||||||
"suffix",
|
"suffix",
|
||||||
"temperature",
|
"temperature",
|
||||||
|
@ -151,7 +136,13 @@ class MistralTextCompletionConfig:
|
||||||
"stop",
|
"stop",
|
||||||
]
|
]
|
||||||
|
|
||||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
for param, value in non_default_params.items():
|
for param, value in non_default_params.items():
|
||||||
if param == "suffix":
|
if param == "suffix":
|
||||||
optional_params["suffix"] = value
|
optional_params["suffix"] = value
|
||||||
|
|
|
@ -1,22 +1,20 @@
|
||||||
from typing import List, Literal, Tuple
|
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||||
|
from litellm.llms.base_llm.transformation import BaseLLMException
|
||||||
from litellm.types.llms.vertex_ai import PartType
|
from litellm.types.llms.vertex_ai import PartType
|
||||||
|
|
||||||
|
|
||||||
class VertexAIError(Exception):
|
class VertexAIError(BaseLLMException):
|
||||||
def __init__(self, status_code, message):
|
def __init__(
|
||||||
self.status_code = status_code
|
self,
|
||||||
self.message = message
|
status_code: int,
|
||||||
self.request = httpx.Request(
|
message: str,
|
||||||
method="POST", url=" https://cloud.google.com/vertex-ai/"
|
headers: Optional[Union[Dict, httpx.Headers]] = None,
|
||||||
)
|
):
|
||||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
super().__init__(message=message, status_code=status_code, headers=headers)
|
||||||
super().__init__(
|
|
||||||
self.message
|
|
||||||
) # Call the base class constructor with the parameters it needs
|
|
||||||
|
|
||||||
|
|
||||||
def get_supports_system_message(
|
def get_supports_system_message(
|
||||||
|
|
|
@ -299,11 +299,13 @@ def _transform_request_body(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if custom_llm_provider == "gemini":
|
if custom_llm_provider == "gemini":
|
||||||
content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
|
content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
|
||||||
messages=messages
|
messages=messages
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
|
content = litellm.VertexGeminiConfig()._transform_messages(
|
||||||
|
messages=messages
|
||||||
|
)
|
||||||
tools: Optional[Tools] = optional_params.pop("tools", None)
|
tools: Optional[Tools] = optional_params.pop("tools", None)
|
||||||
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
|
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
|
||||||
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
|
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
|
||||||
|
@ -460,15 +462,3 @@ def _transform_system_message(
|
||||||
return SystemInstructions(parts=system_content_blocks), messages
|
return SystemInstructions(parts=system_content_blocks), messages
|
||||||
|
|
||||||
return None, messages
|
return None, messages
|
||||||
|
|
||||||
|
|
||||||
def set_headers(auth_header: Optional[str], extra_headers: Optional[dict]) -> dict:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
if auth_header is not None:
|
|
||||||
headers["Authorization"] = f"Bearer {auth_header}"
|
|
||||||
if extra_headers is not None:
|
|
||||||
headers.update(extra_headers)
|
|
||||||
|
|
||||||
return headers
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import httpx # type: ignore
|
import httpx # type: ignore
|
||||||
|
@ -30,6 +31,7 @@ import litellm.litellm_core_utils
|
||||||
import litellm.litellm_core_utils.litellm_logging
|
import litellm.litellm_core_utils.litellm_logging
|
||||||
from litellm import verbose_logger
|
from litellm import verbose_logger
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
|
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
AsyncHTTPHandler,
|
AsyncHTTPHandler,
|
||||||
HTTPHandler,
|
HTTPHandler,
|
||||||
|
@ -86,10 +88,16 @@ from .transformation import (
|
||||||
_gemini_convert_messages_with_history,
|
_gemini_convert_messages_with_history,
|
||||||
_process_gemini_image,
|
_process_gemini_image,
|
||||||
async_transform_request_body,
|
async_transform_request_body,
|
||||||
set_headers,
|
|
||||||
sync_transform_request_body,
|
sync_transform_request_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
|
||||||
|
LoggingClass = LiteLLMLoggingObj
|
||||||
|
else:
|
||||||
|
LoggingClass = Any
|
||||||
|
|
||||||
|
|
||||||
class VertexAIConfig:
|
class VertexAIConfig:
|
||||||
"""
|
"""
|
||||||
|
@@ -277,7 +285,7 @@ class VertexAIConfig:
     ]


-class VertexGeminiConfig:
+class VertexGeminiConfig(BaseConfig):
     """
     Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
     Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
@@ -338,23 +346,9 @@ class VertexGeminiConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

-    def get_supported_openai_params(self):
+    def get_supported_openai_params(self, model: str) -> List[str]:
         return [
             "temperature",
             "top_p",
@@ -473,12 +467,11 @@ class VertexGeminiConfig:

     def map_openai_params(
         self,
+        non_default_params: Dict,
+        optional_params: Dict,
         model: str,
-        non_default_params: dict,
-        optional_params: dict,
         drop_params: bool,
-    ):
+    ) -> Dict:

         for param, value in non_default_params.items():
             if param == "temperature":
                 optional_params["temperature"] = value
@@ -751,38 +744,38 @@ class VertexGeminiConfig:

         return model_response

-    def _transform_response(
+    def transform_response(
         self,
         model: str,
-        response: httpx.Response,
+        raw_response: httpx.Response,
         model_response: ModelResponse,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
-        optional_params: dict,
-        litellm_params: dict,
-        api_key: str,
-        data: Union[dict, str, RequestBody],
-        messages: List,
-        print_verbose,
-        encoding,
+        logging_obj: LoggingClass,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
     ) -> ModelResponse:

         ## LOGGING
         logging_obj.post_call(
             input=messages,
             api_key="",
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
         )

         ## RESPONSE OBJECT
         try:
-            completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
+            completion_response = GenerateContentResponseBody(**raw_response.json())  # type: ignore
         except Exception as e:
             raise VertexAIError(
                 message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
-                    response.text, str(e)
+                    raw_response.text, str(e)
                 ),
                 status_code=422,
+                headers=raw_response.headers,
             )

         ## GET MODEL ##
@@ -915,14 +908,53 @@ class VertexGeminiConfig:
                     completion_response, str(e)
                 ),
                 status_code=422,
+                headers=raw_response.headers,
             )

         return model_response

-    @staticmethod
-    def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[ContentType]:
         return _gemini_convert_messages_with_history(messages=messages)

+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return VertexAIError(
+            message=error_message, status_code=status_code, headers=headers
+        )
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        headers: Dict,
+    ) -> Dict:
+        raise NotImplementedError(
+            "Vertex AI has a custom implementation of transform_request. Needs sync + async."
+        )
+
+    def validate_environment(
+        self,
+        headers: Optional[Dict],
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        api_key: Optional[str] = None,
+    ) -> Dict:
+        default_headers = {
+            "Content-Type": "application/json",
+        }
+        if api_key is not None:
+            default_headers["Authorization"] = f"Bearer {api_key}"
+        if headers is not None:
+            default_headers.update(headers)
+
+        return default_headers
+

 class GoogleAIStudioGeminiConfig(
     VertexGeminiConfig
@@ -978,23 +1010,9 @@ class GoogleAIStudioGeminiConfig(

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

-    def get_supported_openai_params(self):
+    def get_supported_openai_params(self, model: str) -> List[str]:
         return [
             "temperature",
             "top_p",
@@ -1012,22 +1030,27 @@ class GoogleAIStudioGeminiConfig(

     def map_openai_params(
         self,
-        model: str,
         non_default_params: Dict,
         optional_params: Dict,
+        model: str,
         drop_params: bool,
-    ):
+    ) -> Dict:

         # drop frequency_penalty and presence_penalty
         if "frequency_penalty" in non_default_params:
             del non_default_params["frequency_penalty"]
         if "presence_penalty" in non_default_params:
             del non_default_params["presence_penalty"]
         return super().map_openai_params(
-            model, non_default_params, optional_params, drop_params
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            drop_params=drop_params,
         )

-    @staticmethod
-    def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[ContentType]:
         """
         Google AI Studio Gemini does not support image urls in messages.
         """
@@ -1075,9 +1098,14 @@ async def make_call(
        raise VertexAIError(
            status_code=e.response.status_code,
            message=VertexGeminiConfig().translate_exception_str(exception_string),
+            headers=e.response.headers,
        )
    if response.status_code != 200:
-        raise VertexAIError(status_code=response.status_code, message=response.text)
+        raise VertexAIError(
+            status_code=response.status_code,
+            message=response.text,
+            headers=response.headers,
+        )

    completion_stream = ModelResponseIterator(
        streaming_response=response.aiter_lines(), sync_stream=False
@@ -1111,7 +1139,11 @@ def make_sync_call(
    response = client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
-        raise VertexAIError(status_code=response.status_code, message=response.read())
+        raise VertexAIError(
+            status_code=response.status_code,
+            message=str(response.read()),
+            headers=response.headers,
+        )

    completion_stream = ModelResponseIterator(
        streaming_response=response.iter_lines(), sync_stream=True
@@ -1182,7 +1214,13 @@ class VertexLLM(VertexBase):
            should_use_v1beta1_features=should_use_v1beta1_features,
        )

-        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        headers = VertexGeminiConfig().validate_environment(
+            api_key=auth_header,
+            headers=extra_headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+        )

        ## LOGGING
        logging_obj.pre_call(
@@ -1263,7 +1301,13 @@ class VertexLLM(VertexBase):
            should_use_v1beta1_features=should_use_v1beta1_features,
        )

-        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        headers = VertexGeminiConfig().validate_environment(
+            api_key=auth_header,
+            headers=extra_headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+        )

        request_body = await async_transform_request_body(**data)  # type: ignore
        _async_client_params = {}
@@ -1287,23 +1331,32 @@ class VertexLLM(VertexBase):
        )

        try:
-            response = await client.post(api_base, headers=headers, json=request_body)  # type: ignore
+            response = await client.post(
+                api_base, headers=headers, json=cast(dict, request_body)
+            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
-            raise VertexAIError(status_code=error_code, message=err.response.text)
+            raise VertexAIError(
+                status_code=error_code,
+                message=err.response.text,
+                headers=err.response.headers,
+            )
        except httpx.TimeoutException:
-            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+            raise VertexAIError(
+                status_code=408,
+                message="Timeout error occurred.",
+                headers=None,
+            )

-        return VertexGeminiConfig()._transform_response(
+        return VertexGeminiConfig().transform_response(
            model=model,
-            response=response,
+            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key="",
-            data=request_body,
+            request_data=cast(dict, request_body),
            messages=messages,
-            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
@@ -1421,7 +1474,13 @@ class VertexLLM(VertexBase):
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
        )
-        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        headers = VertexGeminiConfig().validate_environment(
+            api_key=auth_header,
+            headers=extra_headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+        )

        ## TRANSFORMATION ##
        data = sync_transform_request_body(**transform_request_params)
@@ -1479,21 +1538,28 @@ class VertexLLM(VertexBase):
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
-            raise VertexAIError(status_code=error_code, message=err.response.text)
+            raise VertexAIError(
+                status_code=error_code,
+                message=err.response.text,
+                headers=err.response.headers,
+            )
        except httpx.TimeoutException:
-            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+            raise VertexAIError(
+                status_code=408,
+                message="Timeout error occurred.",
+                headers=None,
+            )

-        return VertexGeminiConfig()._transform_response(
+        return VertexGeminiConfig().transform_response(
            model=model,
-            response=response,
+            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            litellm_params=litellm_params,
            api_key="",
-            data=data,  # type: ignore
+            request_data=data,  # type: ignore
            messages=messages,
-            print_verbose=print_verbose,
            encoding=encoding,
        )

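A hedged usage sketch (not part of the commit) of the BaseConfig-style surface that `VertexGeminiConfig` exposes after these hunks: `get_supported_openai_params(model=...)`, the keyword-only `map_openai_params(...)`, and `validate_environment(...)` for header construction. The Gemini model name and the auth token below are assumed placeholders.

```python
# Illustrative sketch only; model name and token are assumed placeholders.
import litellm

config = litellm.VertexGeminiConfig()

# The supported-params lookup now takes the model name.
supported = config.get_supported_openai_params(model="gemini-1.5-pro")

# map_openai_params is keyword-only and returns the mapped dict.
optional_params = config.map_openai_params(
    non_default_params={"temperature": 0.2},
    optional_params={},
    model="gemini-1.5-pro",
    drop_params=False,
)

# validate_environment builds the request headers from an auth token.
headers = config.validate_environment(
    api_key="<vertex-auth-token>",  # assumed: a bearer token you already hold
    headers=None,
    model="gemini-1.5-pro",
    messages=[{"role": "user", "content": "hi"}],
    optional_params=optional_params,
)
print(supported, headers)
```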
@@ -2,9 +2,10 @@ import types
 from typing import Literal, Optional, Union

 import litellm
+from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig


-class VolcEngineConfig:
+class VolcEngineConfig(OpenAILikeChatConfig):
     frequency_penalty: Optional[int] = None
     function_call: Optional[Union[str, dict]] = None
     functions: Optional[list] = None
@@ -38,21 +39,7 @@ class VolcEngineConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

     def get_supported_openai_params(self, model: str) -> list:
         return [
@@ -77,14 +64,3 @@ class VolcEngineConfig:
             "max_retries",
             "extra_headers",
         ]  # works across all models
-
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ) -> dict:
-        supported_openai_params = self.get_supported_openai_params(model)
-        for param, value in non_default_params.items():
-            if param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            elif param in supported_openai_params:
-                optional_params[param] = value
-        return optional_params
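A brief, hedged sketch (not part of the commit): with `VolcEngineConfig` now subclassing `OpenAILikeChatConfig`, `get_config()` and OpenAI-parameter mapping are expected to come from the shared base class rather than the removed local override. The model name below is an assumed placeholder.

```python
# Illustrative sketch only; the model name is an assumed placeholder.
import litellm

cfg = litellm.VolcEngineConfig()
print(cfg.get_supported_openai_params(model="volcengine-endpoint-id"))
print(cfg.get_config())  # now delegates to the OpenAI-like base class
```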
@@ -274,6 +274,7 @@ class IBMWatsonXAIConfig(BaseConfig):
         request_data: Dict,
         messages: List[AllMessageValues],
         optional_params: Dict,
+        litellm_params: Dict,
         encoding: str,
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
litellm/main.py
@@ -83,26 +83,13 @@ from litellm.utils import (
 from ._logging import verbose_logger
 from .caching.caching import disable_cache, enable_cache, update_cache
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import (
-    aleph_alpha,
-    baseten,
-    maritalk,
-    nlp_cloud,
-    ollama_chat,
-    oobabooga,
-    openrouter,
-    palm,
-    petals,
-    replicate,
-)
-from .llms.ai21 import completion as ai21
+from .llms import aleph_alpha, baseten, maritalk, ollama_chat, petals
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
 from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
-from .llms.azure_ai.chat import AzureAIChatCompletion
+from .llms.azure.completion.handler import AzureTextCompletion
 from .llms.azure_ai.embed import AzureAIEmbedding
-from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
@@ -111,13 +98,16 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
+from .llms.deprecated_providers import palm
 from .llms.groq.chat.handler import GroqChatCompletion
-from .llms.huggingface_restapi import Huggingface
+from .llms.huggingface.chat.handler import Huggingface
+from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
+from .llms.oobabooga.chat import oobabooga
 from .llms.ollama.completion import handler as ollama
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
-from .llms.openai.chat.o1_handler import OpenAIO1ChatCompletion
 from .llms.openai.completion.handler import OpenAITextCompletion
 from .llms.openai.openai import OpenAIChatCompletion
+from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.predibase import PredibaseChatCompletion
 from .llms.prompt_templates.common_utils import get_completion_messages
@@ -131,6 +121,7 @@ from .llms.prompt_templates.factory import (
 )
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
+from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
@@ -159,7 +150,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
 from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
     VertexAIModelGardenModels,
 )
-from .llms.vllm.completion import handler
+from .llms.vllm.completion import handler as vllm_handler
 from .llms.watsonx.chat.handler import WatsonXChatHandler
 from .llms.watsonx.completion.handler import IBMWatsonXAI
 from .types.llms.openai import (
@@ -196,12 +187,10 @@ from litellm.utils import (
 ####### ENVIRONMENT VARIABLES ###################
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
-openai_o1_chat_completions = OpenAIO1ChatCompletion()
 openai_audio_transcriptions = OpenAIAudioTranscription()
 databricks_chat_completions = DatabricksChatCompletion()
 groq_chat_completions = GroqChatCompletion()
 together_ai_text_completions = TogetherAITextCompletion()
-azure_ai_chat_completions = AzureAIChatCompletion()
 azure_ai_embedding = AzureAIEmbedding()
 anthropic_chat_completions = AnthropicChatCompletion()
 azure_chat_completions = AzureChatCompletion()
@@ -228,6 +217,7 @@ watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
+openai_like_chat_completion = OpenAILikeChatHandler()
 databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 sagemaker_chat_completion = SagemakerChatHandler()
@@ -449,6 +439,7 @@ async def acompletion(
        or custom_llm_provider == "cerebras"
        or custom_llm_provider == "sambanova"
        or custom_llm_provider == "ai21_chat"
+        or custom_llm_provider == "ai21"
        or custom_llm_provider == "volcengine"
        or custom_llm_provider == "codestral"
        or custom_llm_provider == "text-completion-codestral"
@@ -1316,7 +1307,7 @@ def completion(  # type: ignore  # noqa: PLR0915

            ## COMPLETION CALL
            try:
-                response = azure_ai_chat_completions.completion(
+                response = openai_chat_completions.completion(
                    model=model,
                    messages=messages,
                    headers=headers,
@@ -1513,9 +1504,7 @@ def completion(  # type: ignore  # noqa: PLR0915
            or custom_llm_provider == "nvidia_nim"
            or custom_llm_provider == "cerebras"
            or custom_llm_provider == "sambanova"
-            or custom_llm_provider == "ai21_chat"
            or custom_llm_provider == "volcengine"
-            or custom_llm_provider == "codestral"
            or custom_llm_provider == "deepseek"
            or custom_llm_provider == "anyscale"
            or custom_llm_provider == "mistral"
@@ -1562,46 +1551,25 @@ def completion(  # type: ignore  # noqa: PLR0915

            ## COMPLETION CALL
            try:
-                if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
-                    response = openai_o1_chat_completions.completion(
-                        model=model,
-                        messages=messages,
-                        headers=headers,
-                        model_response=model_response,
-                        print_verbose=print_verbose,
-                        api_key=api_key,
-                        api_base=api_base,
-                        acompletion=acompletion,
-                        logging_obj=logging,
-                        optional_params=optional_params,
-                        litellm_params=litellm_params,
-                        logger_fn=logger_fn,
-                        timeout=timeout,  # type: ignore
-                        custom_prompt_dict=custom_prompt_dict,
-                        client=client,  # pass AsyncOpenAI, OpenAI client
-                        organization=organization,
-                        custom_llm_provider=custom_llm_provider,
-                    )
-                else:
-                    response = openai_chat_completions.completion(
-                        model=model,
-                        messages=messages,
-                        headers=headers,
-                        model_response=model_response,
-                        print_verbose=print_verbose,
-                        api_key=api_key,
-                        api_base=api_base,
-                        acompletion=acompletion,
-                        logging_obj=logging,
-                        optional_params=optional_params,
-                        litellm_params=litellm_params,
-                        logger_fn=logger_fn,
-                        timeout=timeout,  # type: ignore
-                        custom_prompt_dict=custom_prompt_dict,
-                        client=client,  # pass AsyncOpenAI, OpenAI client
-                        organization=organization,
-                        custom_llm_provider=custom_llm_provider,
-                    )
+                response = openai_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,  # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client,  # pass AsyncOpenAI, OpenAI client
+                    organization=organization,
+                    custom_llm_provider=custom_llm_provider,
+                )
            except Exception as e:
                ## LOGGING - log the original exception returned
                logging.post_call(
@@ -1627,7 +1595,6 @@ def completion(  # type: ignore  # noqa: PLR0915
            or model in litellm.replicate_models
        ):
            # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
-            replicate_key = None
            replicate_key = (
                api_key
                or litellm.replicate_key
@@ -1645,7 +1612,7 @@ def completion(  # type: ignore  # noqa: PLR0915

            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-            model_response = replicate.completion(  # type: ignore
+            model_response = replicate_chat_completion(  # type: ignore
                model=model,
                messages=messages,
                api_base=api_base,
@@ -1659,6 +1626,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                logging_obj=logging,
                custom_prompt_dict=custom_prompt_dict,
                acompletion=acompletion,
+                headers=headers,
            )

            if optional_params.get("stream", False) is True:
@@ -1806,7 +1774,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                or "https://api.nlpcloud.io/v1/gpu/"
            )

-            response = nlp_cloud.completion(
+            response = nlp_cloud_chat_completion(
                model=model,
                messages=messages,
                api_base=api_base,
@@ -1969,10 +1937,10 @@ def completion(  # type: ignore  # noqa: PLR0915
                api_base
                or litellm.api_base
                or get_secret("MARITALK_API_BASE")
-                or "https://chat.maritaca.ai/api/chat/inference"
+                or "https://chat.maritaca.ai/api"
            )

-            model_response = maritalk.completion(
+            model_response = openai_like_chat_completion.completion(
                model=model,
                messages=messages,
                api_base=api_base,
@@ -1984,17 +1952,10 @@ def completion(  # type: ignore  # noqa: PLR0915
                encoding=encoding,
                api_key=maritalk_key,
                logging_obj=logging,
+                custom_llm_provider="maritalk",
+                custom_prompt_dict=custom_prompt_dict,
            )

-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="maritalk",
-                    logging_obj=logging,
-                )
-                return response
            response = model_response
        elif custom_llm_provider == "huggingface":
            custom_llm_provider = "huggingface"
@@ -2012,7 +1973,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                model=model,
                messages=messages,
                api_base=api_base,  # type: ignore
-                headers=hf_headers,
+                headers=hf_headers or {},
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
@@ -2024,6 +1985,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                logging_obj=logging,
                custom_prompt_dict=custom_prompt_dict,
                timeout=timeout,  # type: ignore
+                client=client,
            )
            if (
                "stream" in optional_params
@@ -2146,7 +2108,7 @@ def completion(  # type: ignore  # noqa: PLR0915
            headers = openrouter_headers

            ## Load Config
-            config = openrouter.OpenrouterConfig.get_config()
+            config = litellm.OpenrouterConfig.get_config()
            for k, v in config.items():
                if k == "extra_body":
                    # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models
@@ -2190,30 +2152,9 @@ def completion(  # type: ignore  # noqa: PLR0915
            """
            pass
        elif custom_llm_provider == "palm":
-            palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key
-            # palm does not support streaming as yet :(
-            model_response = palm.completion(
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                api_key=palm_api_key,
-                logging_obj=logging,
+            raise ValueError(
+                "Palm was decommisioned on October 2024. Please use the `gemini/` route for Gemini Google AI Studio Models. Announcement: https://ai.google.dev/palm_docs/palm?hl=en"
            )
-            # fake palm streaming
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # fake streaming for palm
-                resp_string = model_response["choices"][0]["message"]["content"]
-                response = CustomStreamWrapper(
-                    resp_string, model, custom_llm_provider="palm", logging_obj=logging
-                )
-                return response
-            response = model_response
        elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
            vertex_ai_project = (
                optional_params.pop("vertex_project", None)
@@ -2475,51 +2416,9 @@ def completion(  # type: ignore  # noqa: PLR0915
            ):
                return _model_response
            response = _model_response
-        elif custom_llm_provider == "ai21":
-            custom_llm_provider = "ai21"
-            ai21_key = (
-                api_key
-                or litellm.ai21_key
-                or os.environ.get("AI21_API_KEY")
-                or litellm.api_key
-            )
-
-            api_base = (
-                api_base
-                or litellm.api_base
-                or get_secret("AI21_API_BASE")
-                or "https://api.ai21.com/studio/v1/"
-            )
-
-            model_response = ai21.completion(
-                model=model,
-                messages=messages,
-                api_base=api_base,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                api_key=ai21_key,
-                logging_obj=logging,
-            )
-
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="ai21",
-                    logging_obj=logging,
-                )
-                return response
-
-            ## RESPONSE OBJECT
-            response = model_response
        elif custom_llm_provider == "sagemaker_chat":
            # boto3 reads keys from .env
-            response = sagemaker_chat_completion.completion(
+            model_response = sagemaker_chat_completion.completion(
                model=model,
                messages=messages,
                model_response=model_response,
@@ -2531,9 +2430,13 @@ def completion(  # type: ignore  # noqa: PLR0915
                encoding=encoding,
                logging_obj=logging,
                acompletion=acompletion,
-                headers=headers or {},
            )
-        elif custom_llm_provider == "sagemaker":
+            ## RESPONSE OBJECT
+            response = model_response
+        elif (
+            custom_llm_provider == "sagemaker"
+        ):
            # boto3 reads keys from .env
            model_response = sagemaker_llm.completion(
                model=model,
@@ -2691,7 +2594,7 @@ def completion(  # type: ignore  # noqa: PLR0915
            response = response
        elif custom_llm_provider == "vllm":
            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = handler.completion(
+            model_response = vllm_handler.completion(
                model=model,
                messages=messages,
                custom_prompt_dict=custom_prompt_dict,
@@ -3872,6 +3775,7 @@ async def atext_completion(
        or custom_llm_provider == "cerebras"
        or custom_llm_provider == "sambanova"
        or custom_llm_provider == "ai21_chat"
+        or custom_llm_provider == "ai21"
        or custom_llm_provider == "volcengine"
        or custom_llm_provider == "text-completion-codestral"
        or custom_llm_provider == "deepseek"
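A hedged sketch (not part of the commit) of what the routing changes above imply for callers: `ai21` now sits alongside the other OpenAI-compatible providers in the async check, while the `palm` branch fails fast with a pointer to the `gemini/` route. The model name and API key below are assumed placeholders.

```python
# Illustrative sketch only; model name and key are assumed placeholders.
import asyncio
import os

import litellm

os.environ["AI21_API_KEY"] = "your-ai21-key"


async def main() -> None:
    # "ai21" is now included in acompletion's OpenAI-compatible provider check.
    resp = await litellm.acompletion(
        model="ai21/jamba-1.5-mini",  # assumed example model name
        messages=[{"role": "user", "content": "hello"}],
    )
    print(resp.choices[0].message.content)

    try:
        await litellm.acompletion(
            model="palm/chat-bison",
            messages=[{"role": "user", "content": "hello"}],
        )
    except Exception as err:
        # The palm branch now raises a decommission error; litellm may re-wrap it.
        print(err)


asyncio.run(main())
```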
@@ -56,6 +56,7 @@ class AnthropicPassthroughLoggingHandler:
                request_data={},
                encoding=litellm.encoding,
                json_mode=False,
+                litellm_params={},
            )

            kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
@@ -41,20 +41,19 @@ class VertexPassthroughLoggingHandler:

            instance_of_vertex_llm = litellm.VertexGeminiConfig()
            litellm_model_response: litellm.ModelResponse = (
-                instance_of_vertex_llm._transform_response(
+                instance_of_vertex_llm.transform_response(
                    model=model,
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
-                    response=httpx_response,
+                    raw_response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
-                    data={},
-                    print_verbose=litellm.print_verbose,
-                    encoding=None,
+                    request_data={},
+                    encoding=litellm.encoding,
                )
            )
            kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
288
litellm/utils.py
288
litellm/utils.py
|
@ -2923,22 +2923,16 @@ def get_optional_params( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
|
||||||
if stream:
|
optional_params = litellm.ReplicateConfig().map_openai_params(
|
||||||
optional_params["stream"] = stream
|
non_default_params=non_default_params,
|
||||||
# return optional_params
|
optional_params=optional_params,
|
||||||
if max_tokens is not None:
|
model=model,
|
||||||
if "vicuna" in model or "flan" in model:
|
drop_params=(
|
||||||
optional_params["max_length"] = max_tokens
|
drop_params
|
||||||
elif "meta/codellama-13b" in model:
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
optional_params["max_tokens"] = max_tokens
|
else False
|
||||||
else:
|
),
|
||||||
optional_params["max_new_tokens"] = max_tokens
|
)
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stop_sequences"] = stop
|
|
||||||
elif custom_llm_provider == "predibase":
|
elif custom_llm_provider == "predibase":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
@ -2954,7 +2948,14 @@ def get_optional_params( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
optional_params = litellm.HuggingfaceConfig().map_openai_params(
|
optional_params = litellm.HuggingfaceConfig().map_openai_params(
|
||||||
non_default_params=non_default_params, optional_params=optional_params
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "together_ai":
|
elif custom_llm_provider == "together_ai":
|
||||||
## check if unsupported param passed in
|
## check if unsupported param passed in
|
||||||
|
@ -2973,53 +2974,6 @@ def get_optional_params( # noqa: PLR0915
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "ai21":
|
|
||||||
## check if unsupported param passed in
|
|
||||||
supported_params = get_supported_openai_params(
|
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
|
||||||
)
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if n is not None:
|
|
||||||
optional_params["numResults"] = n
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["maxTokens"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["topP"] = top_p
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stopSequences"] = stop
|
|
||||||
if frequency_penalty is not None:
|
|
||||||
optional_params["frequencyPenalty"] = {"scale": frequency_penalty}
|
|
||||||
if presence_penalty is not None:
|
|
||||||
optional_params["presencePenalty"] = {"scale": presence_penalty}
|
|
||||||
elif (
|
|
||||||
custom_llm_provider == "palm"
|
|
||||||
): # https://developers.generativeai.google/tutorials/curl_quickstart
|
|
||||||
## check if unsupported param passed in
|
|
||||||
supported_params = get_supported_openai_params(
|
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
|
||||||
)
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if n is not None:
|
|
||||||
optional_params["candidate_count"] = n
|
|
||||||
if stop is not None:
|
|
||||||
if isinstance(stop, str):
|
|
||||||
optional_params["stop_sequences"] = [stop]
|
|
||||||
elif isinstance(stop, list):
|
|
||||||
optional_params["stop_sequences"] = stop
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_output_tokens"] = max_tokens
|
|
||||||
elif custom_llm_provider == "vertex_ai" and (
|
elif custom_llm_provider == "vertex_ai" and (
|
||||||
model in litellm.vertex_chat_models
|
model in litellm.vertex_chat_models
|
||||||
or model in litellm.vertex_code_chat_models
|
or model in litellm.vertex_code_chat_models
|
||||||
|
@ -3120,12 +3074,25 @@ def get_optional_params( # noqa: PLR0915
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
if "codestral" in model:
|
if "codestral" in model:
|
||||||
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
|
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
|
||||||
non_default_params=non_default_params, optional_params=optional_params
|
model=model,
|
||||||
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
optional_params = litellm.MistralConfig().map_openai_params(
|
optional_params = litellm.MistralConfig().map_openai_params(
|
||||||
|
model=model,
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
|
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -3326,29 +3293,28 @@ def get_optional_params( # noqa: PLR0915
|
||||||
model=model,
|
model=model,
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "nlp_cloud":
|
elif custom_llm_provider == "nlp_cloud":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
optional_params = litellm.NLPCloudConfig().map_openai_params(
|
||||||
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_length"] = max_tokens
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if presence_penalty is not None:
|
|
||||||
optional_params["presence_penalty"] = presence_penalty
|
|
||||||
if frequency_penalty is not None:
|
|
||||||
optional_params["frequency_penalty"] = frequency_penalty
|
|
||||||
if n is not None:
|
|
||||||
optional_params["num_return_sequences"] = n
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stop_sequences"] = stop
|
|
||||||
elif custom_llm_provider == "petals":
|
elif custom_llm_provider == "petals":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
@ -3435,7 +3401,14 @@ def get_optional_params( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
optional_params = litellm.MistralConfig().map_openai_params(
|
optional_params = litellm.MistralConfig().map_openai_params(
|
||||||
non_default_params=non_default_params, optional_params=optional_params
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "text-completion-codestral":
|
elif custom_llm_provider == "text-completion-codestral":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -3443,7 +3416,14 @@ def get_optional_params( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
|
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
|
||||||
non_default_params=non_default_params, optional_params=optional_params
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif custom_llm_provider == "databricks":
|
elif custom_llm_provider == "databricks":
|
||||||
|
@ -3470,6 +3450,11 @@ def get_optional_params( # noqa: PLR0915
|
||||||
model=model,
|
model=model,
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "cerebras":
|
elif custom_llm_provider == "cerebras":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -3480,6 +3465,11 @@ def get_optional_params( # noqa: PLR0915
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model=model,
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "xai":
|
elif custom_llm_provider == "xai":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -3491,7 +3481,7 @@ def get_optional_params( # noqa: PLR0915
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "ai21_chat":
|
elif custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
|
@ -3500,6 +3490,11 @@ def get_optional_params( # noqa: PLR0915
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model=model,
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "fireworks_ai":
|
elif custom_llm_provider == "fireworks_ai":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -3525,6 +3520,11 @@ def get_optional_params( # noqa: PLR0915
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model=model,
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "hosted_vllm":
|
elif custom_llm_provider == "hosted_vllm":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -3594,55 +3594,17 @@ def get_optional_params( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
|
||||||
if functions is not None:
|
optional_params = litellm.OpenrouterConfig().map_openai_params(
|
||||||
optional_params["functions"] = functions
|
non_default_params=non_default_params,
|
||||||
if function_call is not None:
|
optional_params=optional_params,
|
||||||
optional_params["function_call"] = function_call
|
model=model,
|
||||||
if temperature is not None:
|
drop_params=(
|
||||||
optional_params["temperature"] = temperature
|
drop_params
|
||||||
if top_p is not None:
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
optional_params["top_p"] = top_p
|
else False
|
||||||
if n is not None:
|
),
|
||||||
optional_params["n"] = n
|
|
||||||
if stream is not None:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stop"] = stop
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
if presence_penalty is not None:
|
|
||||||
optional_params["presence_penalty"] = presence_penalty
|
|
||||||
if frequency_penalty is not None:
|
|
||||||
optional_params["frequency_penalty"] = frequency_penalty
|
|
||||||
if logit_bias is not None:
|
|
||||||
optional_params["logit_bias"] = logit_bias
|
|
||||||
if user is not None:
|
|
||||||
optional_params["user"] = user
|
|
||||||
if response_format is not None:
|
|
||||||
optional_params["response_format"] = response_format
|
|
||||||
if seed is not None:
|
|
||||||
optional_params["seed"] = seed
|
|
||||||
if tools is not None:
|
|
||||||
optional_params["tools"] = tools
|
|
||||||
if tool_choice is not None:
|
|
||||||
optional_params["tool_choice"] = tool_choice
|
|
||||||
if max_retries is not None:
|
|
||||||
optional_params["max_retries"] = max_retries
|
|
||||||
|
|
||||||
# OpenRouter-only parameters
|
|
||||||
extra_body = {}
|
|
||||||
transforms = passed_params.pop("transforms", None)
|
|
||||||
models = passed_params.pop("models", None)
|
|
||||||
route = passed_params.pop("route", None)
|
|
||||||
if transforms is not None:
|
|
||||||
extra_body["transforms"] = transforms
|
|
||||||
if models is not None:
|
|
||||||
extra_body["models"] = models
|
|
||||||
if route is not None:
|
|
||||||
extra_body["route"] = route
|
|
||||||
optional_params["extra_body"] = (
|
|
||||||
extra_body # openai client supports `extra_body` param
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif custom_llm_provider == "watsonx":
|
elif custom_llm_provider == "watsonx":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
@@ -3727,7 +3689,11 @@ def get_optional_params(  # noqa: PLR0915
             optional_params=optional_params,
             model=model,
             api_version=api_version,  # type: ignore
-            drop_params=drop_params,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     else:  # assume passing in params for text-completion openai
         supported_params = get_supported_openai_params(
@@ -6271,7 +6237,7 @@ from litellm.llms.base_llm.transformation import BaseConfig

 class ProviderConfigManager:
     @staticmethod
-    def get_provider_chat_config(
+    def get_provider_chat_config(  # noqa: PLR0915
         model: str, provider: litellm.LlmProviders
     ) -> BaseConfig:
         """
@@ -6333,6 +6299,60 @@ class ProviderConfigManager:
             return litellm.LMStudioChatConfig()
         elif litellm.LlmProviders.GALADRIEL == provider:
             return litellm.GaladrielChatConfig()
+        elif litellm.LlmProviders.REPLICATE == provider:
+            return litellm.ReplicateConfig()
+        elif litellm.LlmProviders.HUGGINGFACE == provider:
+            return litellm.HuggingfaceConfig()
+        elif litellm.LlmProviders.TOGETHER_AI == provider:
+            return litellm.TogetherAIConfig()
+        elif litellm.LlmProviders.OPENROUTER == provider:
+            return litellm.OpenrouterConfig()
+        elif litellm.LlmProviders.GEMINI == provider:
+            return litellm.GoogleAIStudioGeminiConfig()
+        elif (
+            litellm.LlmProviders.AI21 == provider
+            or litellm.LlmProviders.AI21_CHAT == provider
+        ):
+            return litellm.AI21ChatConfig()
+        elif litellm.LlmProviders.AZURE == provider:
+            return litellm.AzureOpenAIConfig()
+        elif litellm.LlmProviders.AZURE_AI == provider:
+            return litellm.AzureAIStudioConfig()
+        elif litellm.LlmProviders.AZURE_TEXT == provider:
+            return litellm.AzureOpenAITextConfig()
+        elif litellm.LlmProviders.HOSTED_VLLM == provider:
+            return litellm.HostedVLLMChatConfig()
+        elif litellm.LlmProviders.NLP_CLOUD == provider:
+            return litellm.NLPCloudConfig()
+        elif litellm.LlmProviders.OOBABOOGA == provider:
+            return litellm.OobaboogaConfig()
+        elif litellm.LlmProviders.OLLAMA_CHAT == provider:
+            return litellm.OllamaChatConfig()
+        elif litellm.LlmProviders.DEEPINFRA == provider:
+            return litellm.DeepInfraConfig()
+        elif litellm.LlmProviders.PERPLEXITY == provider:
+            return litellm.PerplexityChatConfig()
+        elif (
+            litellm.LlmProviders.MISTRAL == provider
+            or litellm.LlmProviders.CODESTRAL == provider
+        ):
+            return litellm.MistralConfig()
+        elif litellm.LlmProviders.NVIDIA_NIM == provider:
+            return litellm.NvidiaNimConfig()
+        elif litellm.LlmProviders.CEREBRAS == provider:
+            return litellm.CerebrasConfig()
+        elif litellm.LlmProviders.VOLCENGINE == provider:
+            return litellm.VolcEngineConfig()
+        elif litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL == provider:
+            return litellm.MistralTextCompletionConfig()
+        elif litellm.LlmProviders.SAMBANOVA == provider:
+            return litellm.SambanovaConfig()
+        elif litellm.LlmProviders.MARITALK == provider:
+            return litellm.MaritalkConfig()
+        elif litellm.LlmProviders.CLOUDFLARE == provider:
+            return litellm.CloudflareChatConfig()
+        elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
+            return litellm.AnthropicTextConfig()
         elif litellm.LlmProviders.VLLM == provider:
             return litellm.VLLMConfig()
         elif litellm.LlmProviders.OLLAMA == provider:
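The branches added above let ProviderConfigManager resolve a chat config for many more providers. A small usage sketch of that lookup (the provider and model values are illustrative; the import path matches the test later in this diff):

```python
# Sketch only: resolve a provider's chat config via the mapping added above.
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="gpt-3.5-turbo", provider=litellm.LlmProviders.OPENROUTER
)
print(type(config).__name__)  # expected: OpenrouterConfig, per the branch added above
```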
@@ -168,12 +168,17 @@ def test_all_model_configs():
         drop_params=False,
     ) == {"max_tokens": 10}

-    from litellm.llms.huggingface_restapi import HuggingfaceConfig
+    from litellm.llms.huggingface.chat.handler import HuggingfaceConfig

-    assert "max_completion_tokens" in HuggingfaceConfig().get_supported_openai_params()
-    assert HuggingfaceConfig().map_openai_params({"max_completion_tokens": 10}, {}) == {
-        "max_new_tokens": 10
-    }
+    assert "max_completion_tokens" in HuggingfaceConfig().get_supported_openai_params(
+        model="llama3"
+    )
+    assert HuggingfaceConfig().map_openai_params(
+        non_default_params={"max_completion_tokens": 10},
+        optional_params={},
+        model="llama3",
+        drop_params=False,
+    ) == {"max_new_tokens": 10}

     from litellm.llms.nvidia_nim.chat import NvidiaNimConfig

@@ -184,15 +189,19 @@ def test_all_model_configs():
         model="llama3",
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        drop_params=False,
     ) == {"max_tokens": 10}

     from litellm.llms.ollama_chat import OllamaChatConfig

-    assert "max_completion_tokens" in OllamaChatConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in OllamaChatConfig().get_supported_openai_params(
+        model="llama3"
+    )
     assert OllamaChatConfig().map_openai_params(
         model="llama3",
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        drop_params=False,
     ) == {"num_predict": 10}

     from litellm.llms.predibase import PredibaseConfig
@@ -207,11 +216,13 @@ def test_all_model_configs():

     assert (
         "max_completion_tokens"
-        in MistralTextCompletionConfig().get_supported_openai_params()
+        in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
     )
     assert MistralTextCompletionConfig().map_openai_params(
-        {"max_completion_tokens": 10},
-        {},
+        model="llama3",
+        non_default_params={"max_completion_tokens": 10},
+        optional_params={},
+        drop_params=False,
     ) == {"max_tokens": 10}

     from litellm.llms.volcengine import VolcEngineConfig
@@ -223,9 +234,10 @@ def test_all_model_configs():
         model="llama3",
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        drop_params=False,
     ) == {"max_tokens": 10}

-    from litellm.llms.ai21.chat import AI21ChatConfig
+    from litellm.llms.ai21.chat.transformation import AI21ChatConfig

     assert "max_completion_tokens" in AI21ChatConfig().get_supported_openai_params(
         "jamba-1.5-mini@001"
@@ -234,11 +246,14 @@ def test_all_model_configs():
         model="jamba-1.5-mini@001",
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        drop_params=False,
     ) == {"max_tokens": 10}

     from litellm.llms.azure.chat.gpt_transformation import AzureOpenAIConfig

-    assert "max_completion_tokens" in AzureOpenAIConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in AzureOpenAIConfig().get_supported_openai_params(
+        model="gpt-3.5-turbo"
+    )
     assert AzureOpenAIConfig().map_openai_params(
         model="gpt-3.5-turbo",
         non_default_params={"max_completion_tokens": 10},
@@ -266,11 +281,13 @@ def test_all_model_configs():

     assert (
         "max_completion_tokens"
-        in MistralTextCompletionConfig().get_supported_openai_params()
+        in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
     )
     assert MistralTextCompletionConfig().map_openai_params(
+        model="llama3",
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        drop_params=False,
     ) == {"max_tokens": 10}

     from litellm.llms.bedrock.common_utils import (
@@ -341,7 +358,9 @@ def test_all_model_configs():

     assert (
         "max_completion_tokens"
-        in GoogleAIStudioGeminiConfig().get_supported_openai_params()
+        in GoogleAIStudioGeminiConfig().get_supported_openai_params(
+            model="gemini-1.0-pro"
+        )
     )

     assert GoogleAIStudioGeminiConfig().map_openai_params(
@@ -351,7 +370,9 @@ def test_all_model_configs():
         drop_params=False,
     ) == {"max_output_tokens": 10}

-    assert "max_completion_tokens" in VertexGeminiConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in VertexGeminiConfig().get_supported_openai_params(
+        model="gemini-1.0-pro"
+    )

     assert VertexGeminiConfig().map_openai_params(
         model="gemini-1.0-pro",
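The test changes above all follow one pattern: get_supported_openai_params now takes a model, and map_openai_params is called with explicit keyword arguments including drop_params. A hedged sketch of that pattern using the Ollama chat config exercised in the hunks above:

```python
# Sketch only: mirrors the updated config call pattern from the tests above.
from litellm.llms.ollama_chat import OllamaChatConfig

config = OllamaChatConfig()
assert "max_completion_tokens" in config.get_supported_openai_params(model="llama3")
mapped = config.map_openai_params(
    model="llama3",
    non_default_params={"max_completion_tokens": 10},
    optional_params={},
    drop_params=False,
)
print(mapped)  # expected: {"num_predict": 10}, mirroring the assertion above
```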
@@ -190,9 +190,10 @@ def test_databricks_optional_params():
         custom_llm_provider="databricks",
         max_tokens=10,
         temperature=0.2,
+        stream=True,
     )
     print(f"optional_params: {optional_params}")
-    assert len(optional_params) == 2
+    assert len(optional_params) == 3
     assert "user" not in optional_params

@@ -449,8 +449,12 @@ def test_azure_tool_call_invoke_helper():
         {"role": "assistant", "function_call": {"name": "get_weather"}},
     ]

-    transformed_messages = litellm.AzureOpenAIConfig.transform_request(
-        model="gpt-4o", messages=messages, optional_params={}
+    transformed_messages = litellm.AzureOpenAIConfig().transform_request(
+        model="gpt-4o",
+        messages=messages,
+        optional_params={},
+        litellm_params={},
+        headers={},
     )

     assert transformed_messages["messages"] == [
@@ -69,7 +69,7 @@ def test_batch_completions_models():
 def test_batch_completion_models_all_responses():
     try:
         responses = batch_completion_models_all_responses(
-            models=["j2-light", "claude-3-haiku-20240307"],
+            models=["gemini/gemini-1.5-flash", "claude-3-haiku-20240307"],
             messages=[{"role": "user", "content": "write a poem"}],
             max_tokens=10,
         )
@@ -1606,30 +1606,33 @@ HF Tests we should pass
 #####################################################
 #####################################################
 # Test util to sort models to TGI, conv, None
+from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig


 def test_get_hf_task_for_model():
     model = "glaiveai/glaive-coder-7b"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "text-generation-inference"

     model = "meta-llama/Llama-2-7b-hf"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "text-generation-inference"

     model = "facebook/blenderbot-400M-distill"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "conversational"

     model = "facebook/blenderbot-3B"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "conversational"

     # neither Conv or None
     model = "roneneldan/TinyStories-3M"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "text-generation"

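Task detection moves from the old litellm.llms.huggingface_restapi module onto HuggingfaceChatConfig. A brief sketch of the new call, using the import path from the hunk above:

```python
# Sketch only: task detection via the config class, per the test above.
from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig

task, _ = HuggingfaceChatConfig().get_hf_task_for_model("glaiveai/glaive-coder-7b")
print(task)  # expected: "text-generation-inference", per the assertion above
```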
@@ -1717,14 +1720,17 @@ def tgi_mock_post(url, **kwargs):
 def test_hf_test_completion_tgi():
     litellm.set_verbose = True
     try:
-        with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
+        client = HTTPHandler()
+
+        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
             response = completion(
                 model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                 messages=[{"content": "Hello, how are you?", "role": "user"}],
                 max_tokens=10,
                 wait_for_model=True,
+                client=client,
             )
+            mock_client.assert_called_once()
             # Add any assertions-here to check the response
             print(response)
             assert "options" in mock_client.call_args.kwargs["data"]
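The Hugging Face tests now build an HTTPHandler, patch its post method, and hand the handler to completion() via client=..., instead of patching requests.post globally. A sketch of that pattern only, not the repo's actual test; the HTTPHandler import path and the placeholder fake_post body are assumptions:

```python
# Sketch only: client-injection style mocking, assuming HTTPHandler's import path.
from unittest.mock import patch

from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler  # assumed path


def fake_post(url, **kwargs):
    ...  # placeholder: return an object mimicking the provider's HTTP response


client = HTTPHandler()
with patch.object(client, "post", side_effect=fake_post) as mock_post:
    try:
        completion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"role": "user", "content": "Hello"}],
            client=client,  # routes the HTTP call through the patched handler
        )
    except Exception:
        pass  # the placeholder response is not parseable; we only check routing
    mock_post.assert_called_once()
```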
@@ -1862,13 +1868,15 @@ def mock_post(url, **kwargs):

 def test_hf_classifier_task():
     try:
-        with patch("requests.post", side_effect=mock_post):
+        client = HTTPHandler()
+        with patch.object(client, "post", side_effect=mock_post):
             litellm.set_verbose = True
             user_message = "I like you. I love you"
             messages = [{"content": user_message, "role": "user"}]
             response = completion(
                 model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
                 messages=messages,
+                client=client,
             )
             print(f"response: {response}")
             assert isinstance(response, litellm.ModelResponse)
@@ -3096,19 +3104,20 @@ async def test_completion_replicate_llama3(sync_mode):
             response = completion(
                 model=model_name,
                 messages=messages,
+                max_tokens=10,
             )
         else:
             response = await litellm.acompletion(
                 model=model_name,
                 messages=messages,
+                max_tokens=10,
             )
             print(f"ASYNC REPLICATE RESPONSE - {response}")
-        print(response)
+        print(f"REPLICATE RESPONSE - {response}")
         # Add any assertions here to check the response
         assert isinstance(response, litellm.ModelResponse)
+        assert len(response.choices[0].message.content.strip()) > 0
         response_format_tests(response=response)
-    except litellm.APIError as e:
-        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -3745,22 +3754,6 @@ def test_mistral_anyscale_stream():
 #         pytest.fail(f"Error occurred: {e}")


-#### Test A121 ###################
-@pytest.mark.skip(reason="Local test")
-def test_completion_ai21():
-    print("running ai21 j2light test")
-    litellm.set_verbose = True
-    model_name = "j2-light"
-    try:
-        response = completion(
-            model=model_name, messages=messages, max_tokens=100, temperature=0.8
-        )
-        # Add any assertions here to check the response
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_ai21()
 # test_completion_ai21()
 ## test deep infra
@@ -165,10 +165,10 @@ def test_get_gpt3_tokens():
 # test_get_gpt3_tokens()


-def test_get_palm_tokens():
+def test_get_gemini_tokens():
     # # 🦄🦄🦄🦄🦄🦄🦄🦄
-    max_tokens = get_max_tokens("palm/chat-bison")
-    assert max_tokens == 4096
+    max_tokens = get_max_tokens("gemini/gemini-1.5-flash")
+    assert max_tokens == 8192
     print(max_tokens)

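With the PaLM model strings gone, the token-limit lookup is asserted against a Gemini model instead. A one-line check mirroring the updated test:

```python
# Sketch only: same lookup the renamed test performs.
from litellm import get_max_tokens

print(get_max_tokens("gemini/gemini-1.5-flash"))  # 8192, per the assertion above
```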
@@ -29,19 +29,6 @@ def logger_fn(user_model_dict):
     pass


-# completion with num retries + impact on exception mapping
-def test_completion_with_num_retries():
-    try:
-        response = completion(
-            model="j2-ultra",
-            messages=[{"messages": "vibe", "bad": "message"}],
-            num_retries=2,
-        )
-        pytest.fail(f"Unmapped exception occurred")
-    except Exception as e:
-        pass
-
-
 # test_completion_with_num_retries()
 def test_completion_with_0_num_retries():
     try:
@@ -290,35 +290,46 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
     assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val


-def test_provider_config_manager():
-    from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
-    from litellm.utils import ProviderConfigManager
-    from litellm.llms.base_llm.transformation import BaseConfig
-    from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
-
-    for provider in LITELLM_CHAT_PROVIDERS:
-        if provider == LlmProviders.TRITON or provider == LlmProviders.PREDIBASE:
-            continue
-        assert isinstance(
-            ProviderConfigManager.get_provider_chat_config(
-                model="gpt-3.5-turbo", provider=LlmProviders(provider)
-            ),
-            BaseConfig,
-        ), f"Provider {provider} is not a subclass of BaseConfig"
-
-        config = ProviderConfigManager.get_provider_chat_config(
-            model="gpt-3.5-turbo", provider=LlmProviders(provider)
-        )
-
-        if (
-            provider != litellm.LlmProviders.OPENAI
-            and provider != litellm.LlmProviders.OPENAI_LIKE
-            and provider != litellm.LlmProviders.CUSTOM_OPENAI
-        ):
-            assert (
-                config.__class__.__name__ != "OpenAIGPTConfig"
-            ), f"Provider {provider} is an instance of OpenAIGPTConfig"
-
-            assert (
-                "_abc_impl" not in config.get_config()
-            ), f"Provider {provider} has _abc_impl"
+from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
+from litellm.utils import ProviderConfigManager
+from litellm.llms.base_llm.transformation import BaseConfig
+
+
+def _check_provider_config(config: BaseConfig, provider: LlmProviders):
+    assert isinstance(
+        config,
+        BaseConfig,
+    ), f"Provider {provider} is not a subclass of BaseConfig. Got={config}"
+
+    if (
+        provider != litellm.LlmProviders.OPENAI
+        and provider != litellm.LlmProviders.OPENAI_LIKE
+        and provider != litellm.LlmProviders.CUSTOM_OPENAI
+    ):
+        assert (
+            config.__class__.__name__ != "OpenAIGPTConfig"
+        ), f"Provider {provider} is an instance of OpenAIGPTConfig"
+
+    assert "_abc_impl" not in config.get_config(), f"Provider {provider} has _abc_impl"
+
+
+# def test_provider_config_manager():
+#     from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+#     for provider in LITELLM_CHAT_PROVIDERS:
+#         if (
+#             provider == LlmProviders.VERTEX_AI
+#             or provider == LlmProviders.VERTEX_AI_BETA
+#             or provider == LlmProviders.BEDROCK
+#             or provider == LlmProviders.BASETEN
+#             or provider == LlmProviders.SAGEMAKER
+#             or provider == LlmProviders.SAGEMAKER_CHAT
+#             or provider == LlmProviders.VLLM
+#             or provider == LlmProviders.PETALS
+#             or provider == LlmProviders.OLLAMA
+#         ):
+#             continue
+#         config = ProviderConfigManager.get_provider_chat_config(
+#             model="gpt-3.5-turbo", provider=LlmProviders(provider)
+#         )
+#         _check_provider_config(config, provider)

@@ -522,6 +522,7 @@ async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
     )


+@pytest.mark.flaky(retries=5, delay=3)
 @pytest.mark.asyncio
 async def test_get_gcs_logging_config_without_service_account():
     """
@@ -167,51 +167,6 @@ def cohere_test_completion():

 # cohere_test_completion()

-# AI21
-
-
-def ai21_test_completion():
-    litellm.AI21Config(maxTokens=10)
-    litellm.set_verbose = True
-    try:
-        # OVERRIDE WITH DYNAMIC MAX TOKENS
-        response_1 = litellm.completion(
-            model="j2-mid",
-            messages=[
-                {
-                    "content": "Hello, how are you? Be as verbose as possible",
-                    "role": "user",
-                }
-            ],
-            max_tokens=100,
-        )
-        response_1_text = response_1.choices[0].message.content
-        print(f"response_1_text: {response_1_text}")
-
-        # USE CONFIG TOKENS
-        response_2 = litellm.completion(
-            model="j2-mid",
-            messages=[
-                {
-                    "content": "Hello, how are you? Be as verbose as possible",
-                    "role": "user",
-                }
-            ],
-        )
-        response_2_text = response_2.choices[0].message.content
-        print(f"response_2_text: {response_2_text}")
-
-        assert len(response_2_text) < len(response_1_text)
-
-        response_3 = litellm.completion(
-            model="j2-light",
-            messages=[{"content": "Hello, how are you?", "role": "user"}],
-            n=2,
-        )
-        assert len(response_3.choices) > 1
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # ai21_test_completion()

@@ -47,6 +47,7 @@ def cleanup_redis():
         print(f"Error cleaning up Redis: {str(e)}")


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_provider_budgets_e2e_test():
     """
@@ -106,7 +107,7 @@ async def test_provider_budgets_e2e_test():

     print("response.hidden_params", response._hidden_params)

-    await asyncio.sleep(0.5)
+    await asyncio.sleep(1)

     assert response._hidden_params.get("custom_llm_provider") == "azure"

@@ -1931,66 +1931,11 @@ async def test_completion_watsonx_stream():
 #         raise Exception("Empty response received")
 #     except Exception:
 #         pytest.fail(f"error occurred: {traceback.format_exc()}")
-# test_maritalk_streaming()
-# test on openai completion call


-# # test on ai21 completion call
-def ai21_completion_call():
-    try:
-        messages = [
-            {
-                "role": "system",
-                "content": "You are an all-knowing oracle",
-            },
-            {"role": "user", "content": "What is the meaning of the Universe?"},
-        ]
-        response = completion(
-            model="j2-ultra", messages=messages, stream=True, max_tokens=500
-        )
-        print(f"response: {response}")
-        has_finished = False
-        complete_response = ""
-        start_time = time.time()
-        for idx, chunk in enumerate(response):
-            chunk, finished = streaming_format_tests(idx, chunk)
-            has_finished = finished
-            complete_response += chunk
-            if finished:
-                break
-        if has_finished is False:
-            raise Exception("finished reason missing from final chunk")
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-        print(f"completion_response: {complete_response}")
-    except Exception:
-        pytest.fail(f"error occurred: {traceback.format_exc()}")
-
-
 # ai21_completion_call()


-def ai21_completion_call_bad_key():
-    try:
-        api_key = "bad-key"
-        response = completion(
-            model="j2-ultra", messages=messages, stream=True, api_key=api_key
-        )
-        print(f"response: {response}")
-        complete_response = ""
-        start_time = time.time()
-        for idx, chunk in enumerate(response):
-            chunk, finished = streaming_format_tests(idx, chunk)
-            if finished:
-                break
-            complete_response += chunk
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-        print(f"completion_response: {complete_response}")
-    except Exception:
-        pytest.fail(f"error occurred: {traceback.format_exc()}")
-
-
 # ai21_completion_call_bad_key()

@@ -2418,34 +2363,6 @@ def test_completion_openai_with_functions():
 #### Test Async streaming ####


-# # test on ai21 completion call
-async def ai21_async_completion_call():
-    try:
-        response = completion(
-            model="j2-ultra", messages=messages, stream=True, logger_fn=logger_fn
-        )
-        print(f"response: {response}")
-        complete_response = ""
-        start_time = time.time()
-        # Change for loop to async for loop
-        idx = 0
-        async for chunk in response:
-            chunk, finished = streaming_format_tests(idx, chunk)
-            if finished:
-                break
-            complete_response += chunk
-            idx += 1
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-        print(f"complete response: {complete_response}")
-    except Exception:
-        print(f"error occurred: {traceback.format_exc()}")
-        pass
-
-
-# asyncio.run(ai21_async_completion_call())
-
-
 async def completion_call():
     try:
         response = completion(
@@ -3934,6 +3934,7 @@ def test_completion_text_003_prompt_array():


 ##### hugging face tests
+@pytest.mark.skip(reason="local test")
 def test_completion_hf_prompt_array():
     try:
         litellm.set_verbose = True
@@ -437,8 +437,8 @@ def test_token_counter():
     print(tokens)
     assert tokens > 0

-    tokens = token_counter(model="palm/chat-bison", messages=messages)
-    print("palm/chat-bison")
+    tokens = token_counter(model="gemini/chat-bison", messages=messages)
+    print("gemini/chat-bison")
     print(tokens)
     assert tokens > 0

@@ -465,7 +465,7 @@ def test_token_counter():
         ("azure/gpt-4-1106-preview", True),
         ("groq/gemma-7b-it", True),
         ("anthropic.claude-instant-v1", False),
-        ("palm/chat-bison", False),
+        ("gemini/gemini-1.5-flash", True),
     ],
 )
 def test_supports_function_calling(model, expected_bool):