[Feat] Add fireworks AI embedding (#5812)

* add fireworks embedding models

* add fireworks ai

* fireworks ai embeddings support

* is_fireworks_embedding_model

* working fireworks embeddings

* fix health check for wildcard (`*`) models

* fix embedding get optional params

* fix linting errors

* fix pick_cheapest_chat_model_from_llm_provider

* add fireworks ai litellm provider

* docs fireworks embedding models

* fixes for when azure ad token is passed
This commit is contained in:
Ishaan Jaff 2024-09-20 22:23:28 -07:00 committed by GitHub
parent d349d501c8
commit 1d630b61ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 181 additions and 61 deletions

View file

@ -41,6 +41,7 @@ from litellm import ( # type: ignore
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
CustomStreamWrapper,
Usage,
@ -3435,27 +3436,33 @@ def embedding(
)
if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = (
api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
or litellm.AZURE_DEFAULT_API_VERSION
)
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
azure_ad_token = optional_params.pop(
"azure_ad_token", None
) or get_secret_str("AZURE_AD_TOKEN")
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
if api_base is None:
raise ValueError(
f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
)
## EMBEDDING CALL
response = azure_chat_completions.embedding(
model=model,
@ -3477,12 +3484,12 @@ def embedding(
api_base = (
api_base
or litellm.api_base
or get_secret("OPENAI_API_BASE")
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
openai.organization = (
litellm.organization
or get_secret("OPENAI_ORGANIZATION")
or get_secret_str("OPENAI_ORGANIZATION")
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
@ -3490,7 +3497,7 @@ def embedding(
api_key
or litellm.api_key
or litellm.openai_key
or get_secret("OPENAI_API_KEY")
or get_secret_str("OPENAI_API_KEY")
)
api_type = "openai"
api_version = None
@ -3618,7 +3625,9 @@ def embedding(
)
elif custom_llm_provider == "gemini":
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
gemini_api_key = (
api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
)
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
@ -3743,7 +3752,23 @@ def embedding(
print_verbose=print_verbose,
)
elif custom_llm_provider == "mistral":
api_key = api_key or litellm.api_key or get_secret("MISTRAL_API_KEY")
api_key = api_key or litellm.api_key or get_secret_str("MISTRAL_API_KEY")
response = openai_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "fireworks_ai":
api_key = (
api_key or litellm.api_key or get_secret_str("FIREWORKS_AI_API_KEY")
)
response = openai_chat_completions.embedding(
model=model,
input=input,
@ -3757,7 +3782,7 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "voyage":
api_key = api_key or litellm.api_key or get_secret("VOYAGE_API_KEY")
api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
response = openai_chat_completions.embedding(
model=model,
input=input,
@ -5170,11 +5195,11 @@ async def ahealth_check(
response = {}
elif "*" in model:
from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_model_from_llm_provider,
pick_cheapest_chat_model_from_llm_provider,
)
# this is a wildcard model, we need to pick the cheapest chat model from the provider
cheapest_model = pick_cheapest_model_from_llm_provider(
cheapest_model = pick_cheapest_chat_model_from_llm_provider(
custom_llm_provider=custom_llm_provider
)
model_params["model"] = cheapest_model