[Feat] Add fireworks AI embedding (#5812)

* add fireworks embedding models
* add fireworks ai
* fireworks ai embeddings support
* is_fireworks_embedding_model
* working fireworks embeddings
* fix health check
* models
* fix embedding get optional params
* fix linting errors
* fix pick_cheapest_chat_model_from_llm_provider
* add fireworks ai litellm provider
* docs fireworks embedding models
* fixes for when azure ad token is passed

commit 1d630b61ad
parent d349d501c8

9 changed files with 181 additions and 61 deletions
@@ -150,4 +150,18 @@ We support ALL Fireworks AI models, just set `fireworks_ai/` as a prefix when se
 |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | mixtral-8x7b-instruct | `completion(model="fireworks_ai/mixtral-8x7b-instruct", messages)` |
 | firefunction-v1 | `completion(model="fireworks_ai/firefunction-v1", messages)` |
 | llama-v2-70b-chat | `completion(model="fireworks_ai/llama-v2-70b-chat", messages)` |
+
+## Supported Embedding Models
+
+:::info
+
+We support ALL Fireworks AI models; just set `fireworks_ai/` as a prefix when sending embedding requests.
+
+:::
+
+| Model Name | Function Call |
+|-----------------------|-----------------------------------------------------------------|
+| fireworks_ai/nomic-ai/nomic-embed-text-v1.5 | `response = litellm.embedding(model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5", input=input_text)` |
+| fireworks_ai/nomic-ai/nomic-embed-text-v1 | `response = litellm.embedding(model="fireworks_ai/nomic-ai/nomic-embed-text-v1", input=input_text)` |
+| fireworks_ai/WhereIsAI/UAE-Large-V1 | `response = litellm.embedding(model="fireworks_ai/WhereIsAI/UAE-Large-V1", input=input_text)` |
+| fireworks_ai/thenlper/gte-large | `response = litellm.embedding(model="fireworks_ai/thenlper/gte-large", input=input_text)` |
+| fireworks_ai/thenlper/gte-base | `response = litellm.embedding(model="fireworks_ai/thenlper/gte-base", input=input_text)` |
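To make the table concrete, here is a minimal end-to-end call against one of the embedding models above. This is a sketch, not part of the diff; it assumes `FIREWORKS_AI_API_KEY` is set, which is the environment variable the new provider branch in `embedding()` reads (see the hunks further down), and the OpenAI-style response shape.

```python
import os
import litellm

os.environ["FIREWORKS_AI_API_KEY"] = "fw-..."  # placeholder, not a real credential

# dimensions is only supported for nomic-embed-text-v1.5 and later models,
# per the FireworksAIEmbeddingConfig added in this commit.
response = litellm.embedding(
    model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
    input=["good morning from litellm"],
    dimensions=128,
)
print(len(response.data[0]["embedding"]))
```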
@@ -378,6 +378,7 @@ nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
 fireworks_ai_models: List = []
+fireworks_ai_embedding_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 watsonx_models: List = []
@@ -454,6 +455,10 @@ def add_known_models():
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
             if "-to-" not in key:
                 fireworks_ai_models.append(key)
+        elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
+            # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
+            if "-to-" not in key:
+                fireworks_ai_embedding_models.append(key)


 add_known_models()
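For orientation, here is a condensed, self-contained sketch of the registration loop this hunk extends. The `model_cost` entry is hypothetical (real entries ship with litellm's cost map), and the loop body paraphrases surrounding code the hunk does not show.

```python
fireworks_ai_models: list = []
fireworks_ai_embedding_models: list = []

# Hypothetical cost-map entry; litellm ships the real mapping.
model_cost = {
    "fireworks_ai/nomic-ai/nomic-embed-text-v1.5": {
        "litellm_provider": "fireworks_ai-embedding-models",
        "mode": "embedding",
    },
}

def add_known_models():
    for key, value in model_cost.items():
        if value.get("litellm_provider") == "fireworks_ai":
            # skip 'up-to' / '-to-' pricing-tier entries: not real models
            if "-to-" not in key:
                fireworks_ai_models.append(key)
        elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
            if "-to-" not in key:
                fireworks_ai_embedding_models.append(key)

add_known_models()
assert "fireworks_ai/nomic-ai/nomic-embed-text-v1.5" in fireworks_ai_embedding_models
```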
@@ -779,7 +784,7 @@ models_by_provider: dict = {
     "maritalk": maritalk_models,
     "watsonx": watsonx_models,
     "gemini": gemini_models,
-    "fireworks_ai": fireworks_ai_models,
+    "fireworks_ai": fireworks_ai_models + fireworks_ai_embedding_models,
 }

 # mapping for those models which have larger equivalents
@@ -825,6 +830,7 @@ all_embedding_models = (
     + cohere_embedding_models
     + bedrock_embedding_models
     + vertex_embedding_models
+    + fireworks_ai_embedding_models
 )

 ####### IMAGE GENERATION MODELS ###################
@@ -971,6 +977,9 @@ from .llms.cerebras.chat import CerebrasConfig
 from .llms.sambanova.chat import SambanovaConfig
 from .llms.AI21.chat import AI21ChatConfig
 from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
+from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
+    FireworksAIEmbeddingConfig,
+)
 from .llms.volcengine import VolcEngineConfig
 from .llms.text_completion_codestral import MistralTextCompletionConfig
 from .llms.AzureOpenAI.azure import (
@@ -216,7 +216,12 @@ def get_llm_provider(
         dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
     elif custom_llm_provider == "fireworks_ai":
         # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
-        if not model.startswith("accounts/"):
+        if litellm.FireworksAIEmbeddingConfig().is_fireworks_embedding_model(
+            model=model
+        ):
+            # fireworks embedding models do not require the accounts/fireworks prefix: https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
+            pass
+        elif not model.startswith("accounts/"):
             model = f"accounts/fireworks/models/{model}"
         api_base = (
             api_base
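The effect of this branch, sketched as calls to `get_llm_provider` (assuming its usual 4-tuple return of model, provider, dynamic api key, api base):

```python
from litellm import get_llm_provider

# Chat models still get the accounts/fireworks/models/ prefix applied.
model, provider, _, _ = get_llm_provider(
    "mixtral-8x7b-instruct", custom_llm_provider="fireworks_ai"
)
assert model == "accounts/fireworks/models/mixtral-8x7b-instruct"

# Known embedding models are now passed through unchanged.
model, provider, _, _ = get_llm_provider(
    "nomic-ai/nomic-embed-text-v1.5", custom_llm_provider="fireworks_ai"
)
assert model == "nomic-ai/nomic-embed-text-v1.5"
```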
@@ -30,9 +30,9 @@ def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
     return extra_body


-def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
+def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
     """
-    Pick a random model from the LLM provider.
+    Pick the cheapest chat model from the LLM provider.
     """
     if custom_llm_provider not in litellm.models_by_provider:
         raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
@@ -41,9 +41,14 @@ def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
     min_cost = float("inf")
     cheapest_model = None
     for model in known_models:
-        model_info = litellm.get_model_info(
-            model=model, custom_llm_provider=custom_llm_provider
-        )
+        try:
+            model_info = litellm.get_model_info(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except:
+            continue
+        if model_info.get("mode") != "chat":
+            continue
         _cost = model_info.get("input_cost_per_token", 0) + model_info.get(
             "output_cost_per_token", 0
         )
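Pieced together from the two hunks above, the renamed helper now reads roughly as follows. The tail that tracks `min_cost`/`cheapest_model` is not shown in the diff and is paraphrased here.

```python
import litellm

def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
    """Pick the cheapest chat model from the LLM provider."""
    if custom_llm_provider not in litellm.models_by_provider:
        raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
    known_models = litellm.models_by_provider.get(custom_llm_provider, [])
    min_cost = float("inf")
    cheapest_model = None
    for model in known_models:
        try:
            model_info = litellm.get_model_info(
                model=model, custom_llm_provider=custom_llm_provider
            )
        except Exception:
            continue  # skip entries without cost metadata
        if model_info.get("mode") != "chat":
            continue  # embedding models are no longer eligible
        _cost = model_info.get("input_cost_per_token", 0) + model_info.get(
            "output_cost_per_token", 0
        )
        if _cost < min_cost:  # paraphrased tail, not shown in the hunk
            min_cost = _cost
            cheapest_model = model
    return cheapest_model
```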
@@ -1032,9 +1032,9 @@ class AzureChatCompletion(BaseLLM):
         data: dict,
         model_response: EmbeddingResponse,
         azure_client_params: dict,
-        api_key: str,
         input: list,
         logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
         client: Optional[AsyncAzureOpenAI] = None,
         timeout=None,
     ):
@@ -1078,13 +1078,13 @@ class AzureChatCompletion(BaseLLM):
         self,
         model: str,
         input: list,
-        api_key: str,
         api_base: str,
         api_version: str,
         timeout: float,
         logging_obj: LiteLLMLoggingObj,
         model_response: EmbeddingResponse,
         optional_params: dict,
+        api_key: Optional[str] = None,
         azure_ad_token: Optional[str] = None,
         client=None,
         aembedding=None,
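These signature changes tie to the "fixes for when azure ad token is passed" bullet in the commit message: with AD-token auth the client is built from `azure_ad_token` and `api_key` is legitimately absent, so typing it `api_key: str` was wrong. A minimal sketch of that calling pattern, with placeholder endpoint and token values:

```python
from openai import AsyncAzureOpenAI

# With AD-token auth no api_key exists; the embedding helpers must therefore
# accept api_key=None, hence Optional[str] in the new signatures above.
client = AsyncAzureOpenAI(
    azure_endpoint="https://example.openai.azure.com",  # placeholder endpoint
    api_version="2024-02-01",                           # placeholder version
    azure_ad_token="eyJ...",                            # placeholder token
)
```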
@@ -0,0 +1,47 @@
+"""
+This is OpenAI compatible - no transformation is applied.
+"""
+
+import types
+from typing import Literal, Optional, Union
+
+import litellm
+
+
+class FireworksAIEmbeddingConfig:
+    def get_supported_openai_params(self, model: str):
+        """
+        dimensions: only supported in nomic-ai/nomic-embed-text-v1.5 and later models.
+
+        https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
+        """
+        if "nomic-ai" in model:
+            return ["dimensions"]
+        return []
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ):
+        """
+        No transformation is applied - fireworks ai is openai compatible.
+        """
+        supported_openai_params = self.get_supported_openai_params(model)
+        for param, value in non_default_params.items():
+            if param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
+
+    def is_fireworks_embedding_model(self, model: str):
+        """
+        Helper to check if a model is a fireworks embedding model.
+
+        Fireworks embeddings does not support passing /accounts/fireworks in the model name, so we need to know if it's a known embedding model.
+        """
+        if (
+            model in litellm.fireworks_ai_embedding_models
+            or f"fireworks_ai/{model}" in litellm.fireworks_ai_embedding_models
+        ):
+            return True
+
+        return False
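A quick tour of the new config class, using the import path implied by the `__init__.py` hunk above:

```python
from litellm.llms.fireworks_ai.embed.fireworks_ai_transformation import (
    FireworksAIEmbeddingConfig,
)

config = FireworksAIEmbeddingConfig()

# dimensions is only advertised for nomic-ai models
assert config.get_supported_openai_params("nomic-ai/nomic-embed-text-v1.5") == ["dimensions"]
assert config.get_supported_openai_params("thenlper/gte-base") == []

# unsupported params are dropped during mapping
mapped = config.map_openai_params(
    non_default_params={"dimensions": 128, "user": "abc"},
    optional_params={},
    model="nomic-ai/nomic-embed-text-v1.5",
)
assert mapped == {"dimensions": 128}
```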
@@ -41,6 +41,7 @@ from litellm import (  # type: ignore
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
@@ -3435,27 +3436,33 @@ def embedding(
         )
     if azure is True or custom_llm_provider == "azure":
         # azure configs
-        api_type = get_secret("AZURE_API_TYPE") or "azure"
+        api_type = get_secret_str("AZURE_API_TYPE") or "azure"

-        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

         api_version = (
             api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
+            or get_secret_str("AZURE_API_VERSION")
             or litellm.AZURE_DEFAULT_API_VERSION
         )

-        azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
-            "AZURE_AD_TOKEN"
-        )
+        azure_ad_token = optional_params.pop(
+            "azure_ad_token", None
+        ) or get_secret_str("AZURE_AD_TOKEN")

         api_key = (
             api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
         )

+        if api_base is None:
+            raise ValueError(
+                f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
+            )
+
         ## EMBEDDING CALL
         response = azure_chat_completions.embedding(
             model=model,
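The repeated swaps from `get_secret` to `get_secret_str` in this file are a typing fix: `get_secret` can return non-string values from secret managers, while `get_secret_str` (imported above from `litellm.secret_managers.main`) narrows the result to `Optional[str]`. A sketch of the idea; the real implementation's exact signature is not shown in this diff.

```python
import os
from typing import Optional

def get_secret(name: str):
    # stand-in: secret managers may return str, bool, dict, ...
    return os.environ.get(name)

def get_secret_str(name: str) -> Optional[str]:
    # sketch: only hand back the value when it is actually a string
    value = get_secret(name)
    return value if isinstance(value, str) else None

api_key: Optional[str] = get_secret_str("AZURE_API_KEY")
```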
@@ -3477,12 +3484,12 @@ def embedding(
         api_base = (
             api_base
             or litellm.api_base
-            or get_secret("OPENAI_API_BASE")
+            or get_secret_str("OPENAI_API_BASE")
             or "https://api.openai.com/v1"
         )
         openai.organization = (
             litellm.organization
-            or get_secret("OPENAI_ORGANIZATION")
+            or get_secret_str("OPENAI_ORGANIZATION")
             or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
         )
         # set API KEY
@@ -3490,7 +3497,7 @@ def embedding(
             api_key
             or litellm.api_key
             or litellm.openai_key
-            or get_secret("OPENAI_API_KEY")
+            or get_secret_str("OPENAI_API_KEY")
         )
         api_type = "openai"
         api_version = None
@@ -3618,7 +3625,9 @@ def embedding(
         )
     elif custom_llm_provider == "gemini":
-        gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
+        gemini_api_key = (
+            api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
+        )

         response = google_batch_embeddings.batch_embeddings(  # type: ignore
             model=model,
@@ -3743,7 +3752,23 @@ def embedding(
             print_verbose=print_verbose,
         )
     elif custom_llm_provider == "mistral":
-        api_key = api_key or litellm.api_key or get_secret("MISTRAL_API_KEY")
+        api_key = api_key or litellm.api_key or get_secret_str("MISTRAL_API_KEY")
         response = openai_chat_completions.embedding(
             model=model,
             input=input,
             api_base=api_base,
             api_key=api_key,
             logging_obj=logging,
             timeout=timeout,
             model_response=EmbeddingResponse(),
             optional_params=optional_params,
             client=client,
             aembedding=aembedding,
         )
+    elif custom_llm_provider == "fireworks_ai":
+        api_key = (
+            api_key or litellm.api_key or get_secret_str("FIREWORKS_AI_API_KEY")
+        )
+        response = openai_chat_completions.embedding(
+            model=model,
+            input=input,
@@ -3757,7 +3782,7 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "voyage":
-        api_key = api_key or litellm.api_key or get_secret("VOYAGE_API_KEY")
+        api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
         response = openai_chat_completions.embedding(
             model=model,
             input=input,
@@ -5170,11 +5195,11 @@ async def ahealth_check(
             response = {}
         elif "*" in model:
             from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_model_from_llm_provider,
+                pick_cheapest_chat_model_from_llm_provider,
             )

             # this is a wildcard model, we need to pick the cheapest chat model from the provider
-            cheapest_model = pick_cheapest_model_from_llm_provider(
+            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
                 custom_llm_provider=custom_llm_provider
             )
             model_params["model"] = cheapest_model
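So a wildcard health check such as `model="fireworks_ai/*"` now resolves to a concrete chat model before the provider is called, and embedding-only entries can no longer be picked:

```python
from litellm.litellm_core_utils.llm_request_utils import (
    pick_cheapest_chat_model_from_llm_provider,
)

cheapest = pick_cheapest_chat_model_from_llm_provider(
    custom_llm_provider="fireworks_ai"
)
print(cheapest)  # a chat-mode fireworks model, never an embedding model
```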
@@ -657,6 +657,24 @@ def test_mistral_embeddings():
         pytest.fail(f"Error occurred: {e}")


+def test_fireworks_embeddings():
+    try:
+        litellm.set_verbose = True
+        response = litellm.embedding(
+            model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
+            input=["good morning from litellm"],
+        )
+        print(f"response: {response}")
+        assert isinstance(response.usage, litellm.Usage)
+        cost = completion_cost(completion_response=response)
+        print("cost", cost)
+        assert cost > 0.0
+        print(response._hidden_params)
+        assert response._hidden_params["response_cost"] > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_watsonx_embeddings():

     def mock_wx_embed_request(method: str, url: str, **kwargs):
@@ -2610,13 +2610,13 @@ def get_optional_params_embeddings(
                 status_code=500,
                 message="Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
             )
-    if custom_llm_provider == "triton":
+    elif custom_llm_provider == "triton":
         keys = list(non_default_params.keys())
         for k in keys:
             non_default_params.pop(k, None)
         final_params = {**non_default_params, **kwargs}
         return final_params
-    if custom_llm_provider == "databricks":
+    elif custom_llm_provider == "databricks":
         supported_params = get_supported_openai_params(
             model=model or "",
             custom_llm_provider="databricks",
@@ -2628,7 +2628,7 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-    if custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai":
         supported_params = get_supported_openai_params(
             model=model,
             custom_llm_provider="vertex_ai",
@@ -2643,7 +2643,7 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-    if custom_llm_provider == "bedrock":
+    elif custom_llm_provider == "bedrock":
         # if dimensions is in non_default_params -> pass it for model=bedrock/amazon.titan-embed-text-v2
         if "amazon.titan-embed-text-v1" in model:
             object: Any = litellm.AmazonTitanG1Config()
@@ -2666,35 +2666,7 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-        # elif model == "amazon.titan-embed-image-v1":
-        #     supported_params = litellm.AmazonTitanG1Config().get_supported_openai_params()
-        #     _check_valid_arg(supported_params=supported_params)
-        #     optional_params = litellm.AmazonTitanG1Config().map_openai_params(
-        #         non_default_params=non_default_params, optional_params={}
-        #     )
-        #     final_params = {**optional_params, **kwargs}
-        #     return final_params
-
-        # if (
-        #     "dimensions" in non_default_params.keys()
-        #     and "amazon.titan-embed-text-v2" in model
-        # ):
-        #     kwargs["dimensions"] = non_default_params["dimensions"]
-        #     non_default_params.pop("dimensions", None)
-
-        # if len(non_default_params.keys()) > 0:
-        #     if litellm.drop_params is True:  # drop the unsupported non-default values
-        #         keys = list(non_default_params.keys())
-        #         for k in keys:
-        #             non_default_params.pop(k, None)
-        #         final_params = {**non_default_params, **kwargs}
-        #         return final_params
-        #     raise UnsupportedParamsError(
-        #         status_code=500,
-        #         message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
-        #     )
-        # return {**non_default_params, **kwargs}
-    if custom_llm_provider == "mistral":
+    elif custom_llm_provider == "mistral":
         supported_params = get_supported_openai_params(
             model=model,
             custom_llm_provider="mistral",
@@ -2706,7 +2678,20 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
-    if (
+    elif custom_llm_provider == "fireworks_ai":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="fireworks_ai",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.FireworksAIEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}, model=model
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
+
+    elif (
         custom_llm_provider != "openai"
         and custom_llm_provider != "azure"
         and custom_llm_provider not in litellm.openai_compatible_providers
@@ -2723,7 +2708,6 @@ def get_optional_params_embeddings(
             status_code=500,
             message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
         )
-
     final_params = {**non_default_params, **kwargs}
     return final_params
@@ -4293,7 +4277,12 @@ def get_supported_openai_params(
     elif custom_llm_provider == "anthropic":
         return litellm.AnthropicConfig().get_supported_openai_params()
     elif custom_llm_provider == "fireworks_ai":
-        return litellm.FireworksAIConfig().get_supported_openai_params()
+        if request_type == "embeddings":
+            return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
+                model=model
+            )
+        else:
+            return litellm.FireworksAIConfig().get_supported_openai_params()
     elif custom_llm_provider == "nvidia_nim":
         return litellm.NvidiaNimConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "cerebras":
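With the `request_type` switch in place, the supported-params lookup diverges per request type. A small check; the `request_type="embeddings"` argument is exactly what the `get_optional_params_embeddings` hunk above passes:

```python
import litellm

# embeddings: routed to FireworksAIEmbeddingConfig, so nomic models expose dimensions
litellm.get_supported_openai_params(
    model="nomic-ai/nomic-embed-text-v1.5",
    custom_llm_provider="fireworks_ai",
    request_type="embeddings",
)  # -> ["dimensions"]

# chat (the default request_type): still routed to FireworksAIConfig
litellm.get_supported_openai_params(
    model="mixtral-8x7b-instruct",
    custom_llm_provider="fireworks_ai",
)  # -> FireworksAIConfig's chat param list
```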
@@ -4915,6 +4904,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 "litellm_provider"
             ].startswith("vertex_ai"):
                 pass
+            elif custom_llm_provider == "fireworks_ai" and _model_info[
+                "litellm_provider"
+            ].startswith("fireworks_ai"):
+                pass
             else:
                 raise Exception
         elif split_model in litellm.model_cost:
@@ -4929,6 +4922,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 "litellm_provider"
             ].startswith("vertex_ai"):
                 pass
+            elif custom_llm_provider == "fireworks_ai" and _model_info[
+                "litellm_provider"
+            ].startswith("fireworks_ai"):
+                pass
             else:
                 raise Exception
         else:
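The `startswith` comparison in these two hunks matters because embedding entries are registered under the provider string `fireworks_ai-embedding-models` (see the `add_known_models` hunk above): that value is not equal to `fireworks_ai`, but it does start with it.

```python
requested_provider = "fireworks_ai"
registered_provider = "fireworks_ai-embedding-models"  # from the cost-map metadata

assert registered_provider != requested_provider           # equality alone would raise
assert registered_provider.startswith(requested_provider)  # the new check passes
```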