[Feat] Add fireworks AI embedding (#5812)

* add fireworks embedding models

* add fireworks ai

* fireworks ai embeddings support

* is_fireworks_embedding_model

* working fireworks embeddings

* fix health check for wildcard (`*`) models

* fix embedding get optional params

* fix linting errors

* fix pick_cheapest_chat_model_from_llm_provider

* add fireworks ai litellm provider

* docs fireworks embedding models

* fixes for when azure ad token is passed
Ishaan Jaff 2024-09-20 22:23:28 -07:00 committed by GitHub
parent d349d501c8
commit 1d630b61ad
9 changed files with 181 additions and 61 deletions


@@ -150,4 +150,18 @@ We support ALL Fireworks AI models, just set `fireworks_ai/` as a prefix when se
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| mixtral-8x7b-instruct | `completion(model="fireworks_ai/mixtral-8x7b-instruct", messages)` |
| firefunction-v1 | `completion(model="fireworks_ai/firefunction-v1", messages)` |
| llama-v2-70b-chat | `completion(model="fireworks_ai/llama-v2-70b-chat", messages)` |
## Supported Embedding Models
:::info
We support ALL Fireworks AI models; just set `fireworks_ai/` as a prefix when sending embedding requests
:::
| Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------|
| fireworks_ai/nomic-ai/nomic-embed-text-v1.5 | `response = litellm.embedding(model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5", input=input_text)` |
| fireworks_ai/nomic-ai/nomic-embed-text-v1 | `response = litellm.embedding(model="fireworks_ai/nomic-ai/nomic-embed-text-v1", input=input_text)` |
| fireworks_ai/WhereIsAI/UAE-Large-V1 | `response = litellm.embedding(model="fireworks_ai/WhereIsAI/UAE-Large-V1", input=input_text)` |
| fireworks_ai/thenlper/gte-large | `response = litellm.embedding(model="fireworks_ai/thenlper/gte-large", input=input_text)` |
| fireworks_ai/thenlper/gte-base | `response = litellm.embedding(model="fireworks_ai/thenlper/gte-base", input=input_text)` |
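A minimal usage sketch (assumes `FIREWORKS_AI_API_KEY` is set; the `dimensions` parameter is only honored on nomic-ai/nomic-embed-text-v1.5 and later models, per the embedding config added in this commit):

```python
import os
import litellm

os.environ["FIREWORKS_AI_API_KEY"] = "fw-..."  # placeholder key

response = litellm.embedding(
    model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
    input=["good morning from litellm"],
    dimensions=512,  # only supported on nomic-ai/nomic-embed-text-v1.5 and later
)
print(response.data[0]["embedding"][:5])  # first few dimensions of the vector
```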


@@ -378,6 +378,7 @@ nlp_cloud_models: List = []
aleph_alpha_models: List = []
bedrock_models: List = []
fireworks_ai_models: List = []
fireworks_ai_embedding_models: List = []
deepinfra_models: List = []
perplexity_models: List = []
watsonx_models: List = []
@@ -454,6 +455,10 @@ def add_known_models():
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
if "-to-" not in key:
fireworks_ai_models.append(key)
elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
if "-to-" not in key:
fireworks_ai_embedding_models.append(key)
add_known_models()
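For context, `add_known_models()` iterates over `litellm.model_cost`, so the new branch registers any entry whose `litellm_provider` is `fireworks_ai-embedding-models`. A hypothetical entry of that shape (field values here are illustrative, not copied from the real cost map):

```python
# illustrative model_cost entry that the new elif branch would pick up
litellm.model_cost["fireworks_ai/nomic-ai/nomic-embed-text-v1.5"] = {
    "litellm_provider": "fireworks_ai-embedding-models",  # routes into fireworks_ai_embedding_models
    "mode": "embedding",
    "input_cost_per_token": 8e-09,  # illustrative
    "max_input_tokens": 8192,       # illustrative
}
```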
@@ -779,7 +784,7 @@ models_by_provider: dict = {
"maritalk": maritalk_models,
"watsonx": watsonx_models,
"gemini": gemini_models,
"fireworks_ai": fireworks_ai_models,
"fireworks_ai": fireworks_ai_models + fireworks_ai_embedding_models,
}
# mapping for those models which have larger equivalents
@@ -825,6 +830,7 @@ all_embedding_models = (
+ cohere_embedding_models
+ bedrock_embedding_models
+ vertex_embedding_models
+ fireworks_ai_embedding_models
)
####### IMAGE GENERATION MODELS ###################
@@ -971,6 +977,9 @@ from .llms.cerebras.chat import CerebrasConfig
from .llms.sambanova.chat import SambanovaConfig
from .llms.AI21.chat import AI21ChatConfig
from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
FireworksAIEmbeddingConfig,
)
from .llms.volcengine import VolcEngineConfig
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.AzureOpenAI.azure import (


@@ -216,7 +216,12 @@ def get_llm_provider(
dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
elif custom_llm_provider == "fireworks_ai":
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
if not model.startswith("accounts/"):
if litellm.FireworksAIEmbeddingConfig().is_fireworks_embedding_model(
model=model
):
# fireworks embedding models do not require the accounts/fireworks prefix https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
pass
elif not model.startswith("accounts/"):
model = f"accounts/fireworks/models/{model}"
api_base = (
api_base
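The effect on model-name resolution, sketched below; this assumes `get_llm_provider` keeps its usual `(model, provider, dynamic_api_key, api_base)` return shape:

```python
import litellm

# chat models still get the accounts/fireworks/models/ prefix
model, provider, _, _ = litellm.get_llm_provider("fireworks_ai/mixtral-8x7b-instruct")
assert model == "accounts/fireworks/models/mixtral-8x7b-instruct"

# known embedding models are left untouched, since the embeddings
# endpoint does not accept the accounts/ prefix
model, provider, _, _ = litellm.get_llm_provider(
    "fireworks_ai/nomic-ai/nomic-embed-text-v1.5"
)
assert model == "nomic-ai/nomic-embed-text-v1.5"
```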


@@ -30,9 +30,9 @@ def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
return extra_body
def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
"""
Pick a random model from the LLM provider.
Pick the cheapest chat model from the LLM provider.
"""
if custom_llm_provider not in litellm.models_by_provider:
raise ValueError(f"Unknown LLM provider: {custom_llm_provider}")
@@ -41,9 +41,14 @@ def pick_cheapest_model_from_llm_provider(custom_llm_provider: str):
min_cost = float("inf")
cheapest_model = None
for model in known_models:
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
try:
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception:  # no model_info available for this model; skip it
continue
if model_info.get("mode") != "chat":
continue
_cost = model_info.get("input_cost_per_token", 0) + model_info.get(
"output_cost_per_token", 0
)
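A sketch of the renamed helper in use (illustrative; the actual winner depends on the live cost map):

```python
from litellm.litellm_core_utils.llm_request_utils import (
    pick_cheapest_chat_model_from_llm_provider,
)

# embedding-only entries no longer break the scan: get_model_info failures
# are swallowed, and anything whose mode != "chat" is filtered out
cheapest = pick_cheapest_chat_model_from_llm_provider("fireworks_ai")
print(cheapest)  # some fireworks_ai chat model, never an embedding model
```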


@@ -1032,9 +1032,9 @@ class AzureChatCompletion(BaseLLM):
data: dict,
model_response: EmbeddingResponse,
azure_client_params: dict,
api_key: str,
input: list,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
client: Optional[AsyncAzureOpenAI] = None,
timeout=None,
):
@@ -1078,13 +1078,13 @@
self,
model: str,
input: list,
api_key: str,
api_base: str,
api_version: str,
timeout: float,
logging_obj: LiteLLMLoggingObj,
model_response: EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
client=None,
aembedding=None,
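Per the "fixes for when azure ad token is passed" note above, `api_key` is now optional in these embedding signatures. A minimal sketch of the call this unblocks (deployment name, endpoint, and token are placeholders):

```python
import litellm

response = litellm.embedding(
    model="azure/my-embedding-deployment",  # placeholder deployment
    input=["hello from litellm"],
    api_base="https://my-resource.openai.azure.com/",
    api_version="2024-02-01",
    azure_ad_token="eyJ0eXAiOiJKV1Qi...",  # e.g. from azure-identity; no api_key required
)
```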


@@ -0,0 +1,47 @@
"""
This is OpenAI compatible - no transformation is applied
"""
import types
from typing import Literal, Optional, Union
import litellm
class FireworksAIEmbeddingConfig:
def get_supported_openai_params(self, model: str):
"""
`dimensions`: only supported in nomic-ai/nomic-embed-text-v1.5 and later models.
https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
"""
if "nomic-ai" in model:
return ["dimensions"]
return []
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
):
"""
No transformation is applied - fireworks ai is openai compatible
"""
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
def is_fireworks_embedding_model(self, model: str):
"""
Helper to check if a model is a known Fireworks embedding model.
The Fireworks embeddings endpoint does not accept the accounts/fireworks prefix in the model name, so we need to recognize known embedding models up front.
"""
if (
model in litellm.fireworks_ai_embedding_models
or f"fireworks_ai/{model}" in litellm.fireworks_ai_embedding_models
):
return True
return False
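A quick illustration of the three helpers (assumes the embedding models from the docs table are registered in `litellm.fireworks_ai_embedding_models`):

```python
import litellm

config = litellm.FireworksAIEmbeddingConfig()

config.get_supported_openai_params("nomic-ai/nomic-embed-text-v1.5")
# -> ["dimensions"]
config.get_supported_openai_params("thenlper/gte-large")
# -> []

config.map_openai_params(
    non_default_params={"dimensions": 512, "user": "abc"},
    optional_params={},
    model="nomic-ai/nomic-embed-text-v1.5",
)
# -> {"dimensions": 512}; "user" is not copied since it is unsupported

config.is_fireworks_embedding_model("nomic-ai/nomic-embed-text-v1.5")
# -> True, via the "fireworks_ai/"-prefixed registry lookup
```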


@@ -41,6 +41,7 @@ from litellm import ( # type: ignore
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
CustomStreamWrapper,
Usage,
@@ -3435,27 +3436,33 @@ def embedding(
)
if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = (
api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
or litellm.AZURE_DEFAULT_API_VERSION
)
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
azure_ad_token = optional_params.pop(
"azure_ad_token", None
) or get_secret_str("AZURE_AD_TOKEN")
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
if api_base is None:
raise ValueError(
f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
)
## EMBEDDING CALL
response = azure_chat_completions.embedding(
model=model,
@@ -3477,12 +3484,12 @@ def embedding(
api_base = (
api_base
or litellm.api_base
or get_secret("OPENAI_API_BASE")
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
openai.organization = (
litellm.organization
or get_secret("OPENAI_ORGANIZATION")
or get_secret_str("OPENAI_ORGANIZATION")
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
@@ -3490,7 +3497,7 @@ def embedding(
api_key
or litellm.api_key
or litellm.openai_key
or get_secret("OPENAI_API_KEY")
or get_secret_str("OPENAI_API_KEY")
)
api_type = "openai"
api_version = None
@@ -3618,7 +3625,9 @@ def embedding(
)
elif custom_llm_provider == "gemini":
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
gemini_api_key = (
api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
)
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
@@ -3743,7 +3752,23 @@ def embedding(
print_verbose=print_verbose,
)
elif custom_llm_provider == "mistral":
api_key = api_key or litellm.api_key or get_secret("MISTRAL_API_KEY")
api_key = api_key or litellm.api_key or get_secret_str("MISTRAL_API_KEY")
response = openai_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "fireworks_ai":
api_key = (
api_key or litellm.api_key or get_secret_str("FIREWORKS_AI_API_KEY")
)
response = openai_chat_completions.embedding(
model=model,
input=input,
@@ -3757,7 +3782,7 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "voyage":
api_key = api_key or litellm.api_key or get_secret("VOYAGE_API_KEY")
api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
response = openai_chat_completions.embedding(
model=model,
input=input,
@@ -5170,11 +5195,11 @@ async def ahealth_check(
response = {}
elif "*" in model:
from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_model_from_llm_provider,
pick_cheapest_chat_model_from_llm_provider,
)
# this is a wildcard model, so pick the cheapest chat model from the provider
cheapest_model = pick_cheapest_model_from_llm_provider(
cheapest_model = pick_cheapest_chat_model_from_llm_provider(
custom_llm_provider=custom_llm_provider
)
model_params["model"] = cheapest_model


@@ -657,6 +657,24 @@ def test_mistral_embeddings():
pytest.fail(f"Error occurred: {e}")
def test_fireworks_embeddings():
try:
litellm.set_verbose = True
response = litellm.embedding(
model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
input=["good morning from litellm"],
)
print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage)
cost = completion_cost(completion_response=response)
print("cost", cost)
assert cost > 0.0
print(response._hidden_params)
assert response._hidden_params["response_cost"] > 0.0
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_watsonx_embeddings():
def mock_wx_embed_request(method: str, url: str, **kwargs):


@@ -2610,13 +2610,13 @@ def get_optional_params_embeddings(
status_code=500,
message="Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
)
if custom_llm_provider == "triton":
elif custom_llm_provider == "triton":
keys = list(non_default_params.keys())
for k in keys:
non_default_params.pop(k, None)
final_params = {**non_default_params, **kwargs}
return final_params
if custom_llm_provider == "databricks":
elif custom_llm_provider == "databricks":
supported_params = get_supported_openai_params(
model=model or "",
custom_llm_provider="databricks",
@@ -2628,7 +2628,7 @@ def get_optional_params_embeddings(
)
final_params = {**optional_params, **kwargs}
return final_params
if custom_llm_provider == "vertex_ai":
elif custom_llm_provider == "vertex_ai":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="vertex_ai",
@@ -2643,7 +2643,7 @@ def get_optional_params_embeddings(
)
final_params = {**optional_params, **kwargs}
return final_params
if custom_llm_provider == "bedrock":
elif custom_llm_provider == "bedrock":
# if dimensions is in non_default_params -> pass it for model=bedrock/amazon.titan-embed-text-v2
if "amazon.titan-embed-text-v1" in model:
object: Any = litellm.AmazonTitanG1Config()
@@ -2666,35 +2666,7 @@ def get_optional_params_embeddings(
)
final_params = {**optional_params, **kwargs}
return final_params
# elif model == "amazon.titan-embed-image-v1":
# supported_params = litellm.AmazonTitanG1Config().get_supported_openai_params()
# _check_valid_arg(supported_params=supported_params)
# optional_params = litellm.AmazonTitanG1Config().map_openai_params(
# non_default_params=non_default_params, optional_params={}
# )
# final_params = {**optional_params, **kwargs}
# return final_params
# if (
# "dimensions" in non_default_params.keys()
# and "amazon.titan-embed-text-v2" in model
# ):
# kwargs["dimensions"] = non_default_params["dimensions"]
# non_default_params.pop("dimensions", None)
# if len(non_default_params.keys()) > 0:
# if litellm.drop_params is True: # drop the unsupported non-default values
# keys = list(non_default_params.keys())
# for k in keys:
# non_default_params.pop(k, None)
# final_params = {**non_default_params, **kwargs}
# return final_params
# raise UnsupportedParamsError(
# status_code=500,
# message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
# )
# return {**non_default_params, **kwargs}
if custom_llm_provider == "mistral":
elif custom_llm_provider == "mistral":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="mistral",
@@ -2706,7 +2678,20 @@ def get_optional_params_embeddings(
)
final_params = {**optional_params, **kwargs}
return final_params
if (
elif custom_llm_provider == "fireworks_ai":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="fireworks_ai",
request_type="embeddings",
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.FireworksAIEmbeddingConfig().map_openai_params(
non_default_params=non_default_params, optional_params={}, model=model
)
final_params = {**optional_params, **kwargs}
return final_params
elif (
custom_llm_provider != "openai"
and custom_llm_provider != "azure"
and custom_llm_provider not in litellm.openai_compatible_providers
@@ -2723,7 +2708,6 @@ def get_optional_params_embeddings(
status_code=500,
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
)
final_params = {**non_default_params, **kwargs}
return final_params
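End to end, the new `fireworks_ai` branch makes embedding parameter support model-aware; a sketch using the public helper (assuming `get_supported_openai_params` is importable from the top-level `litellm` package):

```python
import litellm

litellm.get_supported_openai_params(
    model="nomic-ai/nomic-embed-text-v1.5",
    custom_llm_provider="fireworks_ai",
    request_type="embeddings",
)
# -> ["dimensions"]

litellm.get_supported_openai_params(
    model="thenlper/gte-large",
    custom_llm_provider="fireworks_ai",
    request_type="embeddings",
)
# -> []; passing dimensions for this model should raise UnsupportedParamsError
#    unless litellm.drop_params is True
```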
@@ -4293,7 +4277,12 @@ def get_supported_openai_params(
elif custom_llm_provider == "anthropic":
return litellm.AnthropicConfig().get_supported_openai_params()
elif custom_llm_provider == "fireworks_ai":
return litellm.FireworksAIConfig().get_supported_openai_params()
if request_type == "embeddings":
return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
model=model
)
else:
return litellm.FireworksAIConfig().get_supported_openai_params()
elif custom_llm_provider == "nvidia_nim":
return litellm.NvidiaNimConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "cerebras":
@@ -4915,6 +4904,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
raise Exception
elif split_model in litellm.model_cost:
@@ -4929,6 +4922,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
raise Exception
else: