Merge branch 'main' into litellm_proxy_team_cache_update

Krish Dholakia · 2024-07-19 21:07:26 -07:00 · committed by GitHub
commit f797597202
15 changed files with 222 additions and 61 deletions

View file

@@ -75,6 +75,7 @@ BEDROCK_CONVERSE_MODELS = [
     "anthropic.claude-v2:1",
     "anthropic.claude-v1",
     "anthropic.claude-instant-v1",
+    "ai21.jamba-instruct-v1:0",
 ]

@@ -195,13 +196,39 @@ async def make_call(
     if client is None:
         client = _get_async_httpx_client()  # Create a new client if none provided

-    response = await client.post(api_base, headers=headers, data=data, stream=True)
+    response = await client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )

     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.text)

-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: (
+            ModelResponse
+        ) = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )

     # LOGGING
     logging_obj.post_call(

@@ -233,13 +260,35 @@ def make_sync_call(
     if client is None:
         client = _get_httpx_client()  # Create a new client if none provided

-    response = client.post(api_base, headers=headers, data=data, stream=True)
+    response = client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )

     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.read())

-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: ModelResponse = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))

     # LOGGING
     logging_obj.post_call(

@@ -1348,7 +1397,7 @@ class BedrockConverseLLM(BaseLLM):
         response: Union[requests.Response, httpx.Response],
         model_response: ModelResponse,
         stream: bool,
-        logging_obj: Logging,
+        logging_obj: Optional[Logging],
         optional_params: dict,
         api_key: str,
         data: Union[dict, str],

@@ -1358,12 +1407,13 @@ class BedrockConverseLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:

         ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
+        if logging_obj is not None:
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=response.text,
+                additional_args={"complete_input_dict": data},
+            )
         print_verbose(f"raw model_response: {response.text}")

         ## RESPONSE OBJECT

@@ -1900,7 +1950,7 @@ class BedrockConverseLLM(BaseLLM):
         if acompletion:
             if isinstance(client, HTTPHandler):
                 client = None
-            if stream is True and provider != "ai21":
+            if stream is True:
                 return self.async_streaming(
                     model=model,
                     messages=messages,

@@ -1937,7 +1987,7 @@ class BedrockConverseLLM(BaseLLM):
            client=client,
        )  # type: ignore

-        if (stream is not None and stream is True) and provider != "ai21":
+        if stream is not None and stream is True:
            streaming_response = CustomStreamWrapper(
                completion_stream=None,

@@ -1981,7 +2031,7 @@ class BedrockConverseLLM(BaseLLM):
            model=model,
            response=response,
            model_response=model_response,
-           stream=stream,
+           stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key="",

@@ -2168,3 +2218,49 @@ class AWSEventStreamDecoder:
             return None
         return chunk.decode()  # type: ignore[no-any-return]
+
+
+class MockResponseIterator:  # for returning ai21 streaming responses
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _chunk_parser(self, chunk_data: ModelResponse) -> GenericStreamingChunk:
+        try:
+            chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
+            processed_chunk = GenericStreamingChunk(
+                text=chunk_data.choices[0].message.content or "",  # type: ignore
+                tool_use=None,
+                is_finished=True,
+                finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
+                usage=ConverseTokenUsageBlock(
+                    inputTokens=chunk_usage.prompt_tokens,
+                    outputTokens=chunk_usage.completion_tokens,
+                    totalTokens=chunk_usage.total_tokens,
+                ),
+                index=0,
+            )
+            return processed_chunk
+        except Exception:
+            raise ValueError(f"Failed to decode chunk: {chunk_data}")
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
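For ai21 models, make_call/make_sync_call above send the request non-streaming and replay the parsed response as a one-chunk stream through MockResponseIterator. A minimal sketch of that contract; the import path and the hand-built placeholder response are assumptions for illustration, not part of this diff:

import litellm

# Assumed module path -- the file names are not shown in this view.
from litellm.llms.bedrock_httpx import MockResponseIterator

# Placeholder standing in for what process_response() would return for a
# non-streaming ai21 Bedrock call.
model_response = litellm.ModelResponse()
model_response.choices[0].message.content = "Hello from Jamba"
model_response.choices[0].finish_reason = "stop"
setattr(
    model_response,
    "usage",
    litellm.Usage(prompt_tokens=5, completion_tokens=3, total_tokens=8),
)

stream = MockResponseIterator(model_response=model_response)
for chunk in stream:
    # Yields exactly one GenericStreamingChunk (is_finished=True), then stops.
    print(chunk)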

View file

@@ -4798,9 +4798,10 @@ async def ahealth_check(
         if isinstance(stack_trace, str):
             stack_trace = stack_trace[:1000]

         if model not in litellm.model_cost and mode is None:
-            raise Exception(
-                "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
-            )
+            return {
+                "error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
+            }
+
         error_to_return = str(e) + " stack trace: " + stack_trace
         return {"error": error_to_return}

View file

@@ -2803,6 +2803,16 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "ai21.jamba-instruct-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 70000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000007,
+        "litellm_provider": "bedrock",
+        "mode": "chat",
+        "supports_system_messages": true
+    },
     "amazon.titan-text-lite-v1": {
         "max_tokens": 4000,
         "max_input_tokens": 42000,

View file

@@ -3,4 +3,4 @@ model_list:
     litellm_params:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+      api_base: os.environ/AZURE_API_BASE

View file

@@ -4,5 +4,7 @@ model_list:
       model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
       api_key: "os.environ/FIREWORKS_AI_API_KEY"

+router_settings:
+  enable_tag_filtering: True # 👈 Key Change
 general_settings:
   master_key: sk-1234
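router_settings.enable_tag_filtering lets the proxy route requests to deployments by tag. A rough client-side sketch, assuming the convention of sending tags under the request metadata; the proxy URL, model alias, and tag name are placeholders (the key matches the master_key above):

import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="fireworks-llama-v3-70b-instruct",  # placeholder alias for the deployment above
    messages=[{"role": "user", "content": "Hello"}],
    # Assumed convention: tags travel in metadata and are matched against
    # tags configured on the proxy's deployments.
    extra_body={"metadata": {"tags": ["my-team"]}},
)
print(response.choices[0].message.content)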

View file

@@ -592,6 +592,8 @@ def test_bedrock_claude_3(image_url):
         assert len(response.choices) > 0
         assert len(response.choices[0].message.content) > 0

+    except litellm.InternalServerError:
+        pass
     except RateLimitError:
         pass
     except Exception as e:

View file

@@ -1348,7 +1348,10 @@ def test_completion_fireworks_ai():
         pytest.fail(f"Error occurred: {e}")


-def test_completion_fireworks_ai_bad_api_base():
+@pytest.mark.parametrize(
+    "api_key, api_base", [(None, "my-bad-api-base"), ("my-bad-api-key", None)]
+)
+def test_completion_fireworks_ai_dynamic_params(api_key, api_base):
     try:
         litellm.set_verbose = True
         messages = [

@@ -1361,7 +1364,8 @@ def test_completion_fireworks_ai_bad_api_base():
         response = completion(
             model="fireworks_ai/accounts/fireworks/models/mixtral-8x7b-instruct",
             messages=messages,
-            api_base="my-bad-api-base",
+            api_base=api_base,
+            api_key=api_key,
         )
         pytest.fail(f"This call should have failed!")
     except Exception as e:

View file

@@ -706,9 +706,9 @@ def test_vertex_ai_completion_cost():
     print("calculated_input_cost: {}".format(calculated_input_cost))


-# @pytest.mark.skip(reason="new test - WIP, working on fixing this")
+@pytest.mark.skip(reason="new test - WIP, working on fixing this")
 def test_vertex_ai_medlm_completion_cost():
-    """Test for medlm completion cost."""
+    """Test for medlm completion cost ."""
     with pytest.raises(Exception) as e:
         model = "vertex_ai/medlm-medium"

View file

@@ -90,6 +90,7 @@ def test_context_window(model):
 models = ["command-nightly"]


+@pytest.mark.skip(reason="duplicate test.")
 @pytest.mark.parametrize("model", models)
 def test_context_window_with_fallbacks(model):
     ctx_window_fallback_dict = {

View file

@@ -1,8 +1,12 @@
 #### What this tests ####
 # This tests if the router timeout error handling during fallbacks

-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
 import pytest

 sys.path.insert(

@@ -12,9 +16,10 @@ sys.path.insert(
 import os

+from dotenv import load_dotenv
+
 import litellm
 from litellm import Router
-from dotenv import load_dotenv

 load_dotenv()

@@ -37,6 +42,7 @@ def test_router_timeouts():
             "litellm_params": {
                 "model": "claude-instant-1.2",
                 "api_key": "os.environ/ANTHROPIC_API_KEY",
+                "mock_response": "hello world",
             },
             "tpm": 20000,
         },

@@ -90,7 +96,9 @@ def test_router_timeouts():
 @pytest.mark.asyncio
 async def test_router_timeouts_bedrock():
-    import openai, uuid
+    import uuid
+
+    import openai

     # Model list for OpenAI and Anthropic models
     _model_list = [

View file

@@ -1312,22 +1312,22 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 #     pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True])  # False
+@pytest.mark.parametrize("sync_mode", [True, False])  #
 @pytest.mark.parametrize(
-    "model",
+    "model, region",
     [
-        "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0",
-        "anthropic.claude-instant-v1",
-        "bedrock/ai21.j2-mid",
-        "mistral.mistral-7b-instruct-v0:2",
-        "bedrock/amazon.titan-tg1-large",
-        "meta.llama3-8b-instruct-v1:0",
-        "cohere.command-text-v14",
+        ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
+        ["bedrock/cohere.command-r-plus-v1:0", None],
+        ["anthropic.claude-3-sonnet-20240229-v1:0", None],
+        ["anthropic.claude-instant-v1", None],
+        ["mistral.mistral-7b-instruct-v0:2", None],
+        ["bedrock/amazon.titan-tg1-large", None],
+        ["meta.llama3-8b-instruct-v1:0", None],
+        ["cohere.command-text-v14", None],
     ],
 )
 @pytest.mark.asyncio
-async def test_bedrock_httpx_streaming(sync_mode, model):
+async def test_bedrock_httpx_streaming(sync_mode, model, region):
     try:
         litellm.set_verbose = True
         if sync_mode:

@@ -1337,6 +1337,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                 messages=messages,
                 max_tokens=10,  # type: ignore
                 stream=True,
+                aws_region_name=region,
             )
             complete_response = ""
             # Add any assertions here to check the response

@@ -1358,6 +1359,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                 messages=messages,
                 max_tokens=100,  # type: ignore
                 stream=True,
+                aws_region_name=region,
             )
             complete_response = ""
             # Add any assertions here to check the response

View file

@@ -20,7 +20,12 @@ from litellm import (
     token_counter,
 )
 from litellm.tests.large_text import text
-from litellm.tests.messages_with_counts import MESSAGES_TEXT, MESSAGES_WITH_IMAGES, MESSAGES_WITH_TOOLS
+from litellm.tests.messages_with_counts import (
+    MESSAGES_TEXT,
+    MESSAGES_WITH_IMAGES,
+    MESSAGES_WITH_TOOLS,
+)


 def test_token_counter_normal_plus_function_calling():
     try:

@@ -55,27 +60,28 @@ def test_token_counter_normal_plus_function_calling():
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


 # test_token_counter_normal_plus_function_calling()


 @pytest.mark.parametrize(
     "message_count_pair",
     MESSAGES_TEXT,
 )
 def test_token_counter_textonly(message_count_pair):
     counted_tokens = token_counter(
-        model="gpt-35-turbo",
-        messages=[message_count_pair["message"]]
+        model="gpt-35-turbo", messages=[message_count_pair["message"]]
     )
     assert counted_tokens == message_count_pair["count"]


 @pytest.mark.parametrize(
     "message_count_pair",
     MESSAGES_WITH_IMAGES,
 )
 def test_token_counter_with_images(message_count_pair):
     counted_tokens = token_counter(
-        model="gpt-4o",
-        messages=[message_count_pair["message"]]
+        model="gpt-4o", messages=[message_count_pair["message"]]
     )
     assert counted_tokens == message_count_pair["count"]

@@ -327,3 +333,13 @@ def test_get_modified_max_tokens(
     ), "Got={}, Expected={}, Params={}".format(
         calculated_value, expected_value, args
     )
+
+
+def test_empty_tools():
+    messages = [{"role": "user", "content": "hey, how's it going?", "tool_calls": None}]
+
+    result = token_counter(
+        messages=messages,
+    )
+
+    print(result)

View file

@@ -1911,7 +1911,7 @@ def token_counter(
     # use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
     is_tool_call = False
     num_tokens = 0
-    if text == None:
+    if text is None:
         if messages is not None:
             print_verbose(f"token_counter messages received: {messages}")
             text = ""

@@ -1937,7 +1937,7 @@ def token_counter(
                             num_tokens += calculage_img_tokens(
                                 data=image_url_str, mode="auto"
                             )
-                if "tool_calls" in message:
+                if message.get("tool_calls"):
                     is_tool_call = True
                     for tool_call in message["tool_calls"]:
                         if "function" in tool_call:
@@ -4398,44 +4398,44 @@ def get_llm_provider(
         if custom_llm_provider == "perplexity":
             # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
             api_base = api_base or "https://api.perplexity.ai"
-            dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY")
+            dynamic_api_key = api_key or get_secret("PERPLEXITYAI_API_KEY")
         elif custom_llm_provider == "anyscale":
             # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
             api_base = api_base or "https://api.endpoints.anyscale.com/v1"
-            dynamic_api_key = get_secret("ANYSCALE_API_KEY")
+            dynamic_api_key = api_key or get_secret("ANYSCALE_API_KEY")
         elif custom_llm_provider == "deepinfra":
             # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
             api_base = api_base or "https://api.deepinfra.com/v1/openai"
-            dynamic_api_key = get_secret("DEEPINFRA_API_KEY")
+            dynamic_api_key = api_key or get_secret("DEEPINFRA_API_KEY")
         elif custom_llm_provider == "empower":
             api_base = api_base or "https://app.empower.dev/api/v1"
-            dynamic_api_key = get_secret("EMPOWER_API_KEY")
+            dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY")
         elif custom_llm_provider == "groq":
             # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
             api_base = api_base or "https://api.groq.com/openai/v1"
-            dynamic_api_key = get_secret("GROQ_API_KEY")
+            dynamic_api_key = api_key or get_secret("GROQ_API_KEY")
         elif custom_llm_provider == "nvidia_nim":
             # nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
             api_base = api_base or "https://integrate.api.nvidia.com/v1"
-            dynamic_api_key = get_secret("NVIDIA_NIM_API_KEY")
+            dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY")
         elif custom_llm_provider == "volcengine":
             # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
             api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
-            dynamic_api_key = get_secret("VOLCENGINE_API_KEY")
+            dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY")
         elif custom_llm_provider == "codestral":
             # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
             api_base = api_base or "https://codestral.mistral.ai/v1"
-            dynamic_api_key = get_secret("CODESTRAL_API_KEY")
+            dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY")
         elif custom_llm_provider == "deepseek":
             # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
             api_base = api_base or "https://api.deepseek.com/v1"
-            dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
+            dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
         elif custom_llm_provider == "fireworks_ai":
             # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
             if not model.startswith("accounts/fireworks/models"):
                 model = f"accounts/fireworks/models/{model}"
             api_base = api_base or "https://api.fireworks.ai/inference/v1"
-            dynamic_api_key = (
+            dynamic_api_key = api_key or (
                 get_secret("FIREWORKS_API_KEY")
                 or get_secret("FIREWORKS_AI_API_KEY")
                 or get_secret("FIREWORKSAI_API_KEY")

@@ -4465,10 +4465,10 @@ def get_llm_provider(
         elif custom_llm_provider == "voyage":
             # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
             api_base = "https://api.voyageai.com/v1"
-            dynamic_api_key = get_secret("VOYAGE_API_KEY")
+            dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY")
         elif custom_llm_provider == "together_ai":
             api_base = "https://api.together.xyz/v1"
-            dynamic_api_key = (
+            dynamic_api_key = api_key or (
                 get_secret("TOGETHER_API_KEY")
                 or get_secret("TOGETHER_AI_API_KEY")
                 or get_secret("TOGETHERAI_API_KEY")

@@ -4476,8 +4476,10 @@ def get_llm_provider(
             )
         elif custom_llm_provider == "friendliai":
             api_base = "https://inference.friendli.ai/v1"
-            dynamic_api_key = get_secret("FRIENDLIAI_API_KEY") or get_secret(
-                "FRIENDLI_TOKEN"
+            dynamic_api_key = (
+                api_key
+                or get_secret("FRIENDLIAI_API_KEY")
+                or get_secret("FRIENDLI_TOKEN")
             )
         if api_base is not None and not isinstance(api_base, str):
             raise Exception(
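Threading `api_key or ...` through these branches lets an explicitly passed key take precedence over the provider's environment variable. A minimal sketch, assuming get_llm_provider's usual return of a (model, custom_llm_provider, dynamic_api_key, api_base) tuple; the model name and keys are placeholders:

import os

from litellm import get_llm_provider

os.environ["GROQ_API_KEY"] = "key-from-env"

model, provider, dynamic_api_key, api_base = get_llm_provider(
    model="groq/llama3-8b-8192",
    api_key="key-passed-in",
)
assert provider == "groq"
# With this change, the key passed in wins over GROQ_API_KEY.
assert dynamic_api_key == "key-passed-in"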
@@ -6813,6 +6815,13 @@ def exception_type(
                     model=model,
                     llm_provider="bedrock",
                 )
+            elif "Could not process image" in error_str:
+                exception_mapping_worked = True
+                raise litellm.InternalServerError(
+                    message=f"BedrockException - {error_str}",
+                    model=model,
+                    llm_provider="bedrock",
+                )
             elif hasattr(original_exception, "status_code"):
                 if original_exception.status_code == 500:
                     exception_mapping_worked = True
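Mapping Bedrock's "Could not process image" failure to litellm.InternalServerError gives callers a typed exception to catch, which is what the updated test_bedrock_claude_3 now tolerates. A short sketch of catching it at the call site (model and image URL are placeholders):

import litellm

try:
    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/image.png"},
                    },
                ],
            }
        ],
    )
    print(response.choices[0].message.content)
except litellm.InternalServerError as e:
    # Raised when Bedrock reports it could not process the image.
    print("Bedrock could not process the request:", e)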

View file

@@ -2803,6 +2803,16 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "ai21.jamba-instruct-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 70000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000007,
+        "litellm_provider": "bedrock",
+        "mode": "chat",
+        "supports_system_messages": true
+    },
     "amazon.titan-text-lite-v1": {
         "max_tokens": 4000,
         "max_input_tokens": 42000,

View file

@@ -48,7 +48,7 @@ const Sidebar: React.FC<SidebarProps> = ({
            style={{ height: "100%", borderRight: 0 }}
          >
            <Menu.Item key="1" onClick={() => setPage("api-keys")}>
-             <Text>API Keys</Text>
+             <Text>Virtual Keys</Text>
            </Menu.Item>
            <Menu.Item key="3" onClick={() => setPage("llm-playground")}>
              <Text>Test Key</Text>