[Feat] Add fireworks AI embedding (#5812)
* add fireworks embedding models
* add fireworks ai
* fireworks ai embeddings support
* is_fireworks_embedding_model
* working fireworks embeddings
* fix health check
* models
* fix embedding get optional params
* fix linting errors
* fix pick_cheapest_chat_model_from_llm_provider
* add fireworks ai litellm provider
* docs fireworks embedding models
* fixes for when azure ad token is passed
parent d349d501c8
commit 1d630b61ad
9 changed files with 181 additions and 61 deletions
@@ -150,4 +150,18 @@ We support ALL Fireworks AI models, just set `fireworks_ai/` as a prefix when se
 |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | mixtral-8x7b-instruct | `completion(model="fireworks_ai/mixtral-8x7b-instruct", messages)` |
 | firefunction-v1 | `completion(model="fireworks_ai/firefunction-v1", messages)` |
 | llama-v2-70b-chat | `completion(model="fireworks_ai/llama-v2-70b-chat", messages)` |
+
+## Supported Embedding Models
+
+:::info
+
+We support ALL Fireworks AI models, just set `fireworks_ai/` as a prefix when sending embedding requests
+
+:::
+
+| Model Name | Function Call |
+|-----------------------|-----------------------------------------------------------------|
+| fireworks_ai/nomic-ai/nomic-embed-text-v1.5 | `response = litellm.embedding(model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5", input=input_text)` |
+| fireworks_ai/nomic-ai/nomic-embed-text-v1 | `response = litellm.embedding(model="fireworks_ai/nomic-ai/nomic-embed-text-v1", input=input_text)` |
+| fireworks_ai/WhereIsAI/UAE-Large-V1 | `response = litellm.embedding(model="fireworks_ai/WhereIsAI/UAE-Large-V1", input=input_text)` |
+| fireworks_ai/thenlper/gte-large | `response = litellm.embedding(model="fireworks_ai/thenlper/gte-large", input=input_text)` |
+| fireworks_ai/thenlper/gte-base | `response = litellm.embedding(model="fireworks_ai/thenlper/gte-base", input=input_text)` |
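For reference, a minimal end-to-end call against one of the embedding models documented above might look like the sketch below (the key is a placeholder; a real call needs a valid `FIREWORKS_AI_API_KEY`):

```python
import os
import litellm

os.environ["FIREWORKS_AI_API_KEY"] = "fw-placeholder"  # placeholder key

response = litellm.embedding(
    model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
    input=["good morning from litellm"],
)
# Each item in response.data carries one embedding vector.
print(len(response.data[0]["embedding"]))
```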
@@ -378,6 +378,7 @@ nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
 fireworks_ai_models: List = []
+fireworks_ai_embedding_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 watsonx_models: List = []
@@ -454,6 +455,10 @@ def add_known_models():
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
             if "-to-" not in key:
                 fireworks_ai_models.append(key)
+        elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
+            # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
+            if "-to-" not in key:
+                fireworks_ai_embedding_models.append(key)


 add_known_models()
@@ -779,7 +784,7 @@ models_by_provider: dict = {
     "maritalk": maritalk_models,
     "watsonx": watsonx_models,
     "gemini": gemini_models,
-    "fireworks_ai": fireworks_ai_models,
+    "fireworks_ai": fireworks_ai_models + fireworks_ai_embedding_models,
 }

 # mapping for those models which have larger equivalents
@@ -825,6 +830,7 @@ all_embedding_models = (
     + cohere_embedding_models
     + bedrock_embedding_models
     + vertex_embedding_models
+    + fireworks_ai_embedding_models
 )

 ####### IMAGE GENERATION MODELS ###################
@@ -971,6 +977,9 @@ from .llms.cerebras.chat import CerebrasConfig
 from .llms.sambanova.chat import SambanovaConfig
 from .llms.AI21.chat import AI21ChatConfig
 from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
+from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
+    FireworksAIEmbeddingConfig,
+)
 from .llms.volcengine import VolcEngineConfig
 from .llms.text_completion_codestral import MistralTextCompletionConfig
 from .llms.AzureOpenAI.azure import (
@@ -216,7 +216,12 @@ def get_llm_provider(
            dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
        elif custom_llm_provider == "fireworks_ai":
            # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
-           if not model.startswith("accounts/"):
+           if litellm.FireworksAIEmbeddingConfig().is_fireworks_embedding_model(
+               model=model
+           ):
+               # fireworks embedding models do not require the accounts/fireworks prefix https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
+               pass
+           elif not model.startswith("accounts/"):
                model = f"accounts/fireworks/models/{model}"
            api_base = (
                api_base
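The practical effect of the branch above: chat models still get the `accounts/fireworks/models/` prefix, while known embedding models pass through untouched. A sketch of the expected behavior — the 4-tuple return shape of `get_llm_provider` is an assumption based on the surrounding code:

```python
import os
import litellm
from litellm import get_llm_provider

os.environ["FIREWORKS_AI_API_KEY"] = "fw-placeholder"  # placeholder key

# Chat model: litellm prepends the accounts/fireworks/models/ prefix.
model, provider, _, _ = get_llm_provider(model="fireworks_ai/mixtral-8x7b-instruct")
assert model == "accounts/fireworks/models/mixtral-8x7b-instruct"
assert provider == "fireworks_ai"

# Embedding model: left as-is, since the Fireworks embeddings endpoint
# rejects the accounts/fireworks prefix.
model, provider, _, _ = get_llm_provider(model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5")
assert model == "nomic-ai/nomic-embed-text-v1.5"
```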
@@ -30,9 +30,9 @@ def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
     return extra_body


-def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
+def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
     """
-    Pick a random model from the LLM provider.
+    Pick the cheapest chat model from the LLM provider.
     """
     if custom_llm_provider not in litellm.models_by_provider:
         raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
@@ -41,9 +41,14 @@ def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
     min_cost = float("inf")
     cheapest_model = None
     for model in known_models:
-        model_info = litellm.get_model_info(
-            model=model, custom_llm_provider=custom_llm_provider
-        )
+        try:
+            model_info = litellm.get_model_info(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except:
+            continue
+        if model_info.get("mode") != "chat":
+            continue
         _cost = model_info.get("input_cost_per_token", 0) + model_info.get(
             "output_cost_per_token", 0
         )
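A standalone sketch of the same selection logic, handy for reasoning about which model a wildcard health check resolves to; the `MODEL_INFO` pricing dict here is illustrative, not real litellm data:

```python
from typing import List, Optional

# Illustrative stand-in for litellm.get_model_info() results.
MODEL_INFO = {
    "fireworks_ai/mixtral-8x7b-instruct": {
        "mode": "chat", "input_cost_per_token": 5e-7, "output_cost_per_token": 5e-7,
    },
    "fireworks_ai/nomic-ai/nomic-embed-text-v1.5": {
        "mode": "embedding", "input_cost_per_token": 8e-9,
    },
}


def pick_cheapest_chat_model(models: List[str]) -> Optional[str]:
    min_cost = float("inf")
    cheapest = None
    for model in models:
        info = MODEL_INFO.get(model)
        if info is None or info.get("mode") != "chat":
            continue  # skip unknown and non-chat (e.g. embedding) models
        cost = info.get("input_cost_per_token", 0) + info.get("output_cost_per_token", 0)
        if cost < min_cost:
            min_cost, cheapest = cost, model
    return cheapest


print(pick_cheapest_chat_model(list(MODEL_INFO)))  # fireworks_ai/mixtral-8x7b-instruct
```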
@@ -1032,9 +1032,9 @@ class AzureChatCompletion(BaseLLM):
         data: dict,
         model_response: EmbeddingResponse,
         azure_client_params: dict,
-        api_key: str,
         input: list,
         logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
         client: Optional[AsyncAzureOpenAI] = None,
         timeout=None,
     ):
@@ -1078,13 +1078,13 @@ class AzureChatCompletion(BaseLLM):
         self,
         model: str,
         input: list,
-        api_key: str,
         api_base: str,
         api_version: str,
         timeout: float,
         logging_obj: LiteLLMLoggingObj,
         model_response: EmbeddingResponse,
         optional_params: dict,
+        api_key: Optional[str] = None,
         azure_ad_token: Optional[str] = None,
         client=None,
         aembedding=None,
@@ -0,0 +1,47 @@
+"""
+This is OpenAI compatible - no transformation is applied
+
+"""
+
+import types
+from typing import Literal, Optional, Union
+
+import litellm
+
+
+class FireworksAIEmbeddingConfig:
+    def get_supported_openai_params(self, model: str):
+        """
+        dimensions: only supported in nomic-ai/nomic-embed-text-v1.5 and later models.
+
+        https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
+        """
+        if "nomic-ai" in model:
+            return ["dimensions"]
+        return []
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ):
+        """
+        No transformation is applied - fireworks ai is openai compatible
+        """
+        supported_openai_params = self.get_supported_openai_params(model)
+        for param, value in non_default_params.items():
+            if param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
+
+    def is_fireworks_embedding_model(self, model: str):
+        """
+        Helper to check if a model is a fireworks embedding model.
+
+        Fireworks embeddings do not support passing /accounts/fireworks in the model name, so we need to know if it's a known embedding model.
+        """
+        if (
+            model in litellm.fireworks_ai_embedding_models
+            or f"fireworks_ai/{model}" in litellm.fireworks_ai_embedding_models
+        ):
+            return True
+
+        return False
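Usage of the new config class follows directly from the methods above. A quick sketch (assumes the fireworks embedding model list has been populated via `add_known_models()` at import time):

```python
import litellm

config = litellm.FireworksAIEmbeddingConfig()

# dimensions is only advertised for nomic-ai models.
assert config.get_supported_openai_params("nomic-ai/nomic-embed-text-v1.5") == ["dimensions"]
assert config.get_supported_openai_params("thenlper/gte-large") == []

# map_openai_params forwards supported params and drops the rest.
params = config.map_openai_params(
    non_default_params={"dimensions": 256, "encoding_format": "float"},
    optional_params={},
    model="nomic-ai/nomic-embed-text-v1.5",
)
assert params == {"dimensions": 256}
```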
@@ -41,6 +41,7 @@ from litellm import ( # type: ignore
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
@@ -3435,27 +3436,33 @@ def embedding(
         )
     if azure is True or custom_llm_provider == "azure":
         # azure configs
-        api_type = get_secret("AZURE_API_TYPE") or "azure"
+        api_type = get_secret_str("AZURE_API_TYPE") or "azure"

-        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

         api_version = (
             api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
+            or get_secret_str("AZURE_API_VERSION")
             or litellm.AZURE_DEFAULT_API_VERSION
         )

-        azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
-            "AZURE_AD_TOKEN"
-        )
+        azure_ad_token = optional_params.pop(
+            "azure_ad_token", None
+        ) or get_secret_str("AZURE_AD_TOKEN")

         api_key = (
             api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
         )
+
+        if api_base is None:
+            raise ValueError(
+                f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
+            )
+
         ## EMBEDDING CALL
         response = azure_chat_completions.embedding(
             model=model,
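A sketch of what the new guard changes for callers: with no `AZURE_API_BASE` configured anywhere, the embedding call now fails fast with a clear message instead of a confusing client error further down. Placeholder credentials below; litellm's exception mapping may re-wrap the `ValueError`:

```python
import os
import litellm

os.environ.pop("AZURE_API_BASE", None)  # deliberately unset
os.environ["AZURE_API_KEY"] = "sk-placeholder"  # placeholder key

try:
    litellm.embedding(model="azure/text-embedding-ada-002", input=["hello"])
except Exception as e:  # may arrive re-wrapped by litellm's exception mapping
    print(e)  # ... Set 'AZURE_API_BASE' in .env
```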
@@ -3477,12 +3484,12 @@ def embedding(
         api_base = (
             api_base
             or litellm.api_base
-            or get_secret("OPENAI_API_BASE")
+            or get_secret_str("OPENAI_API_BASE")
             or "https://api.openai.com/v1"
         )
         openai.organization = (
             litellm.organization
-            or get_secret("OPENAI_ORGANIZATION")
+            or get_secret_str("OPENAI_ORGANIZATION")
             or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
         )
         # set API KEY
@@ -3490,7 +3497,7 @@ def embedding(
             api_key
             or litellm.api_key
             or litellm.openai_key
-            or get_secret("OPENAI_API_KEY")
+            or get_secret_str("OPENAI_API_KEY")
         )
         api_type = "openai"
         api_version = None
@@ -3618,7 +3625,9 @@ def embedding(
         )
     elif custom_llm_provider == "gemini":
-        gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
+        gemini_api_key = (
+            api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
+        )

         response = google_batch_embeddings.batch_embeddings(  # type: ignore
             model=model,
@@ -3743,7 +3752,23 @@ def embedding(
             print_verbose=print_verbose,
         )
     elif custom_llm_provider == "mistral":
-        api_key = api_key or litellm.api_key or get_secret("MISTRAL_API_KEY")
+        api_key = api_key or litellm.api_key or get_secret_str("MISTRAL_API_KEY")
+        response = openai_chat_completions.embedding(
+            model=model,
+            input=input,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
+        )
+    elif custom_llm_provider == "fireworks_ai":
+        api_key = (
+            api_key or litellm.api_key or get_secret_str("FIREWORKS_AI_API_KEY")
+        )
         response = openai_chat_completions.embedding(
             model=model,
             input=input,
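With the branch above, Fireworks embeddings route through the shared OpenAI-compatible client, so async usage works the same way as for other providers. A sketch (placeholder key):

```python
import asyncio
import os
import litellm

os.environ["FIREWORKS_AI_API_KEY"] = "fw-placeholder"  # placeholder key


async def main():
    response = await litellm.aembedding(
        model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
        input=["good morning from litellm"],
    )
    print(response.usage)


asyncio.run(main())
```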
@@ -3757,7 +3782,7 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "voyage":
-        api_key = api_key or litellm.api_key or get_secret("VOYAGE_API_KEY")
+        api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
         response = openai_chat_completions.embedding(
             model=model,
             input=input,
@@ -5170,11 +5195,11 @@ async def ahealth_check(
             response = {}
         elif "*" in model:
             from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_model_from_llm_provider,
+                pick_cheapest_chat_model_from_llm_provider,
             )

             # this is a wildcard model, we need to pick a random model from the provider
-            cheapest_model = pick_cheapest_model_from_llm_provider(
+            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
                 custom_llm_provider=custom_llm_provider
             )
             model_params["model"] = cheapest_model
@@ -657,6 +657,24 @@ def test_mistral_embeddings():
         pytest.fail(f"Error occurred: {e}")


+def test_fireworks_embeddings():
+    try:
+        litellm.set_verbose = True
+        response = litellm.embedding(
+            model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
+            input=["good morning from litellm"],
+        )
+        print(f"response: {response}")
+        assert isinstance(response.usage, litellm.Usage)
+        cost = completion_cost(completion_response=response)
+        print("cost", cost)
+        assert cost > 0.0
+        print(response._hidden_params)
+        assert response._hidden_params["response_cost"] > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_watsonx_embeddings():

     def mock_wx_embed_request(method: str, url: str, **kwargs):
@@ -2610,13 +2610,13 @@ def get_optional_params_embeddings(
                 status_code=500,
                 message="Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
             )
-    if custom_llm_provider == "triton":
+    elif custom_llm_provider == "triton":
         keys = list(non_default_params.keys())
         for k in keys:
             non_default_params.pop(k, None)
         final_params = {**non_default_params, **kwargs}
         return final_params
-    if custom_llm_provider == "databricks":
+    elif custom_llm_provider == "databricks":
         supported_params = get_supported_openai_params(
             model=model or "",
             custom_llm_provider="databricks",
@@ -2628,7 +2628,7 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-    if custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai":
         supported_params = get_supported_openai_params(
             model=model,
             custom_llm_provider="vertex_ai",
@@ -2643,7 +2643,7 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-    if custom_llm_provider == "bedrock":
+    elif custom_llm_provider == "bedrock":
         # if dimensions is in non_default_params -> pass it for model=bedrock/amazon.titan-embed-text-v2
         if "amazon.titan-embed-text-v1" in model:
             object: Any = litellm.AmazonTitanG1Config()
@@ -2666,35 +2666,7 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-    # elif model == "amazon.titan-embed-image-v1":
-    #     supported_params = litellm.AmazonTitanG1Config().get_supported_openai_params()
-    #     _check_valid_arg(supported_params=supported_params)
-    #     optional_params = litellm.AmazonTitanG1Config().map_openai_params(
-    #         non_default_params=non_default_params, optional_params={}
-    #     )
-    #     final_params = {**optional_params, **kwargs}
-    #     return final_params
-
-    # if (
-    #     "dimensions" in non_default_params.keys()
-    #     and "amazon.titan-embed-text-v2" in model
-    # ):
-    #     kwargs["dimensions"] = non_default_params["dimensions"]
-    #     non_default_params.pop("dimensions", None)
-
-    # if len(non_default_params.keys()) > 0:
-    #     if litellm.drop_params is True:  # drop the unsupported non-default values
-    #         keys = list(non_default_params.keys())
-    #         for k in keys:
-    #             non_default_params.pop(k, None)
-    #         final_params = {**non_default_params, **kwargs}
-    #         return final_params
-    #     raise UnsupportedParamsError(
-    #         status_code=500,
-    #         message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
-    #     )
-    # return {**non_default_params, **kwargs}
-    if custom_llm_provider == "mistral":
+    elif custom_llm_provider == "mistral":
         supported_params = get_supported_openai_params(
             model=model,
             custom_llm_provider="mistral",
@@ -2706,7 +2678,20 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-    if (
+    elif custom_llm_provider == "fireworks_ai":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="fireworks_ai",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.FireworksAIEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}, model=model
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
+
+    elif (
         custom_llm_provider != "openai"
         and custom_llm_provider != "azure"
         and custom_llm_provider not in litellm.openai_compatible_providers
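The intended net effect, sketched below: `dimensions` survives for nomic-ai models, while for other Fireworks models it is rejected by `_check_valid_arg` unless `litellm.drop_params` is set. This is inferred from the branch above and assumes a valid `FIREWORKS_AI_API_KEY` in the environment:

```python
import litellm

# Forwarded: nomic-ai models advertise the dimensions param.
litellm.embedding(
    model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
    input=["hello"],
    dimensions=256,
)

# Dropped rather than raising, because drop_params is enabled and
# gte-large does not support dimensions.
litellm.drop_params = True
litellm.embedding(
    model="fireworks_ai/thenlper/gte-large",
    input=["hello"],
    dimensions=256,
)
```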
@@ -2723,7 +2708,6 @@ def get_optional_params_embeddings(
             status_code=500,
             message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
         )
-
     final_params = {**non_default_params, **kwargs}
     return final_params
@@ -4293,7 +4277,12 @@ def get_supported_openai_params(
     elif custom_llm_provider == "anthropic":
         return litellm.AnthropicConfig().get_supported_openai_params()
     elif custom_llm_provider == "fireworks_ai":
-        return litellm.FireworksAIConfig().get_supported_openai_params()
+        if request_type == "embeddings":
+            return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
+                model=model
+            )
+        else:
+            return litellm.FireworksAIConfig().get_supported_openai_params()
     elif custom_llm_provider == "nvidia_nim":
         return litellm.NvidiaNimConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "cerebras":
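A quick check of the new dispatch; the expected values follow from `FireworksAIEmbeddingConfig.get_supported_openai_params` above (the chat-side assertion assumes `FireworksAIConfig` advertises `temperature`, which is an assumption here):

```python
from litellm import get_supported_openai_params

embed_params = get_supported_openai_params(
    model="nomic-ai/nomic-embed-text-v1.5",
    custom_llm_provider="fireworks_ai",
    request_type="embeddings",
)
assert embed_params == ["dimensions"]

chat_params = get_supported_openai_params(
    model="mixtral-8x7b-instruct",
    custom_llm_provider="fireworks_ai",
)
assert "temperature" in chat_params  # assumption: chat config includes temperature
```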
@@ -4915,6 +4904,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     "litellm_provider"
                 ].startswith("vertex_ai"):
                     pass
+                elif custom_llm_provider == "fireworks_ai" and _model_info[
+                    "litellm_provider"
+                ].startswith("fireworks_ai"):
+                    pass
                 else:
                     raise Exception
         elif split_model in litellm.model_cost:
@@ -4929,6 +4922,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     "litellm_provider"
                 ].startswith("vertex_ai"):
                     pass
+                elif custom_llm_provider == "fireworks_ai" and _model_info[
+                    "litellm_provider"
+                ].startswith("fireworks_ai"):
+                    pass
                 else:
                     raise Exception
         else: