feat(utils.py): add native fireworks ai support

Addresses https://github.com/BerriAI/litellm/issues/777 and https://github.com/BerriAI/litellm/issues/2486
Krrish Dholakia 2024-03-15 09:09:59 -07:00
parent 5a2e024576
commit 9909f44015
5 changed files with 42 additions and 0 deletions


@@ -328,6 +328,7 @@ openai_compatible_providers: List = [
     "perplexity",
     "xinference",
     "together_ai",
+    "fireworks_ai",
 ]

@@ -479,6 +480,7 @@ provider_list: List = [
     "voyage",
     "cloudflare",
     "xinference",
+    "fireworks_ai",
     "custom",  # custom apis
 ]
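Registering "fireworks_ai" in both lists is what keeps the rest of the commit small: dispatch can key off list membership instead of yet another per-provider branch. A minimal, self-contained sketch of that pattern (hypothetical names, not litellm's actual internals):

# Sketch: membership in a provider list replaces a growing chain of
# per-provider equality checks in the routing code.
openai_compatible_providers = [
    "perplexity",
    "xinference",
    "together_ai",
    "fireworks_ai",  # newly registered provider
]

def uses_openai_client(custom_llm_provider: str) -> bool:
    # One membership test covers every OpenAI-compatible backend, so
    # onboarding a provider is a single list entry, not a new branch.
    return custom_llm_provider in openai_compatible_providers

assert uses_openai_client("fireworks_ai")
assert not uses_openai_client("bedrock")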


@@ -891,6 +891,7 @@ def completion(
         or custom_llm_provider == "mistral"
         or custom_llm_provider == "openai"
         or custom_llm_provider == "together_ai"
+        or custom_llm_provider in litellm.openai_compatible_providers
         or "ft:gpt-3.5-turbo" in model  # finetune gpt-3.5-turbo
     ):  # allow user to make an openai call with a custom base
         # note: if a user sets a custom base - we should ensure this works

@@ -2393,6 +2394,7 @@ async def aembedding(*args, **kwargs):
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "groq"
+            or custom_llm_provider == "fireworks_ai"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.

@@ -2892,6 +2894,7 @@ async def atext_completion(*args, **kwargs):
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "groq"
+            or custom_llm_provider == "fireworks_ai"
             or custom_llm_provider == "text-completion-openai"
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"


@@ -631,6 +631,13 @@
         "litellm_provider": "groq",
         "mode": "chat"
     },
+    "groq/gemma-7b-it": {
+        "max_tokens": 8192,
+        "input_cost_per_token": 0.00000010,
+        "output_cost_per_token": 0.00000010,
+        "litellm_provider": "groq",
+        "mode": "chat"
+    },
     "claude-instant-1.2": {
         "max_tokens": 100000,
         "max_output_tokens": 8191,


@@ -529,6 +529,25 @@ def test_completion_azure_gpt4_vision():
 # test_completion_azure_gpt4_vision()
 
 
+def test_completion_fireworks_ai():
+    try:
+        litellm.set_verbose = True
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = completion(
+            model="fireworks_ai/accounts/fireworks/models/mixtral-8x7b-instruct",
+            messages=messages,
+        )
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.skip(reason="this test is flaky")
 def test_completion_perplexity_api():
     try:
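To exercise just the new test locally, one option is to invoke pytest programmatically with a name filter. The file path here is an assumption about the repo layout, and a Fireworks key must be exported first (e.g. FIREWORKS_AI_API_KEY):

import pytest

# -k selects the new test by name; the path is hypothetical and may need
# adjusting to wherever test_completion.py lives in your checkout.
pytest.main(["-x", "-k", "test_completion_fireworks_ai",
             "litellm/tests/test_completion.py"])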


@@ -5375,6 +5375,17 @@ def get_llm_provider(
             # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
             api_base = "https://api.groq.com/openai/v1"
             dynamic_api_key = get_secret("GROQ_API_KEY")
+        elif custom_llm_provider == "fireworks_ai":
+            # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
+            if not model.startswith("accounts/fireworks/models"):
+                model = f"accounts/fireworks/models/{model}"
+            api_base = "https://api.fireworks.ai/inference/v1"
+            dynamic_api_key = (
+                get_secret("FIREWORKS_API_KEY")
+                or get_secret("FIREWORKS_AI_API_KEY")
+                or get_secret("FIREWORKSAI_API_KEY")
+                or get_secret("FIREWORKS_AI_TOKEN")
+            )
         elif custom_llm_provider == "mistral":
             # mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
             api_base = (
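The prefixing logic above lets callers pass either a short model name or the fully qualified Fireworks path. A standalone sketch of that normalization, extracted here for illustration (this is not a public litellm helper):

# Short fireworks model names are expanded to the account-scoped path;
# already-qualified names pass through untouched.
FIREWORKS_PREFIX = "accounts/fireworks/models"

def normalize_fireworks_model(model: str) -> str:
    if not model.startswith(FIREWORKS_PREFIX):
        return f"{FIREWORKS_PREFIX}/{model}"
    return model

assert (normalize_fireworks_model("mixtral-8x7b-instruct")
        == "accounts/fireworks/models/mixtral-8x7b-instruct")
assert (normalize_fireworks_model("accounts/fireworks/models/mixtral-8x7b-instruct")
        == "accounts/fireworks/models/mixtral-8x7b-instruct")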