mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
[Feat] Add fireworks AI embedding (#5812)
* add fireworks embedding models * add fireworks ai * fireworks ai embeddings support * is_fireworks_embedding_model * working fireworks embeddings * fix health check * models * fix embedding get optional params * fix linting errors * fix pick_cheapest_chat_model_from_llm_provider * add fireworks ai litellm provider * docs fireworks embedding models * fixes for when azure ad token is passed
This commit is contained in:
parent
4994addba8
commit
28f0dac398
9 changed files with 181 additions and 61 deletions
|
@ -41,6 +41,7 @@ from litellm import ( # type: ignore
|
|||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
Usage,
|
||||
|
@ -3435,27 +3436,33 @@ def embedding(
|
|||
)
|
||||
if azure is True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
|
||||
|
||||
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
|
||||
|
||||
api_version = (
|
||||
api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
or litellm.AZURE_DEFAULT_API_VERSION
|
||||
)
|
||||
|
||||
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
|
||||
"AZURE_AD_TOKEN"
|
||||
)
|
||||
azure_ad_token = optional_params.pop(
|
||||
"azure_ad_token", None
|
||||
) or get_secret_str("AZURE_AD_TOKEN")
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
|
||||
)
|
||||
|
||||
## EMBEDDING CALL
|
||||
response = azure_chat_completions.embedding(
|
||||
model=model,
|
||||
|
@ -3477,12 +3484,12 @@ def embedding(
|
|||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret("OPENAI_API_BASE")
|
||||
or get_secret_str("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
openai.organization = (
|
||||
litellm.organization
|
||||
or get_secret("OPENAI_ORGANIZATION")
|
||||
or get_secret_str("OPENAI_ORGANIZATION")
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
|
@ -3490,7 +3497,7 @@ def embedding(
|
|||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or get_secret("OPENAI_API_KEY")
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
api_type = "openai"
|
||||
api_version = None
|
||||
|
@ -3618,7 +3625,9 @@ def embedding(
|
|||
)
|
||||
elif custom_llm_provider == "gemini":
|
||||
|
||||
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
|
||||
gemini_api_key = (
|
||||
api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
|
||||
)
|
||||
|
||||
response = google_batch_embeddings.batch_embeddings( # type: ignore
|
||||
model=model,
|
||||
|
@ -3743,7 +3752,23 @@ def embedding(
|
|||
print_verbose=print_verbose,
|
||||
)
|
||||
elif custom_llm_provider == "mistral":
|
||||
api_key = api_key or litellm.api_key or get_secret("MISTRAL_API_KEY")
|
||||
api_key = api_key or litellm.api_key or get_secret_str("MISTRAL_API_KEY")
|
||||
response = openai_chat_completions.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
timeout=timeout,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
api_key = (
|
||||
api_key or litellm.api_key or get_secret_str("FIREWORKS_AI_API_KEY")
|
||||
)
|
||||
response = openai_chat_completions.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
|
@ -3757,7 +3782,7 @@ def embedding(
|
|||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "voyage":
|
||||
api_key = api_key or litellm.api_key or get_secret("VOYAGE_API_KEY")
|
||||
api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
|
||||
response = openai_chat_completions.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
|
@ -5170,11 +5195,11 @@ async def ahealth_check(
|
|||
response = {}
|
||||
elif "*" in model:
|
||||
from litellm.litellm_core_utils.llm_request_utils import (
|
||||
pick_cheapest_model_from_llm_provider,
|
||||
pick_cheapest_chat_model_from_llm_provider,
|
||||
)
|
||||
|
||||
# this is a wildcard model, we need to pick a random model from the provider
|
||||
cheapest_model = pick_cheapest_model_from_llm_provider(
|
||||
cheapest_model = pick_cheapest_chat_model_from_llm_provider(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
model_params["model"] = cheapest_model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue