LiteLLM Minor Fixes and Improvements (09/09/2024) (#5602)

* fix(main.py): pass default azure api version as alternative in completion call

Fixes API error caused by an unset Azure API version

Closes https://github.com/BerriAI/litellm/issues/5584

* Fixed gemini-1.5-flash pricing (#5590)

* add /key/list endpoint

* bump: version 1.44.21 → 1.44.22

* docs architecture

* Fixed gemini-1.5-flash pricing

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* fix(bedrock/chat.py): fix converse api stop sequence param mapping

Fixes https://github.com/BerriAI/litellm/issues/5592

* fix(databricks/cost_calculator.py): handle databricks model name changes

Fixes https://github.com/BerriAI/litellm/issues/5597

* fix(azure.py): support azure api version 2024-08-01-preview

Closes https://github.com/BerriAI/litellm/issues/5377

* fix(proxy/_types.py): allow dev keys to call cohere /rerank endpoint

Fixes an issue where only admin keys could call the rerank endpoint

* fix(azure.py): check if model is gpt-4o

* fix(proxy/_types.py): support /v1/rerank on non-admin routes as well

* fix(cost_calculator.py): fix split on `/` logic in cost calculator

---------

Co-authored-by: F1bos <44951186+F1bos@users.noreply.github.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Krish Dholakia 2024-09-09 21:56:12 -07:00 committed by GitHub
parent 4ac66bd843
commit 2d2282101b
14 changed files with 139 additions and 56 deletions

View file

@@ -118,7 +118,7 @@ in_memory_llm_clients_cache: dict = {}
safe_memory_mode: bool = False
enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###
AZURE_DEFAULT_API_VERSION = "2024-07-01-preview" # this is updated to the latest
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview" # this is updated to the latest
### COHERE EMBEDDINGS DEFAULT TYPE ###
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document"
### GUARDRAILS ###
@@ -868,7 +868,7 @@ from .llms.custom_llm import CustomLLM
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic.chat import AnthropicConfig
from .llms.anthropic.completion import AnthropicTextConfig
-from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
+from .llms.databricks.chat import DatabricksConfig, DatabricksEmbeddingConfig
from .llms.predibase import PredibaseConfig
from .llms.replicate import ReplicateConfig
from .llms.cohere.completion import CohereConfig

View file

@@ -22,6 +22,9 @@ from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_cha
from litellm.llms.anthropic.cost_calculation import (
cost_per_token as anthropic_cost_per_token,
)
+from litellm.llms.databricks.cost_calculator import (
+cost_per_token as databricks_cost_per_token,
+)
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@@ -159,7 +162,7 @@ def cost_per_token(
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
model_without_prefix = model
model_parts = model.split("/")
model_parts = model.split("/", 1)
if len(model_parts) > 1:
model_without_prefix = model_parts[1]
else:
@@ -212,6 +215,8 @@ def cost_per_token(
)
elif custom_llm_provider == "anthropic":
return anthropic_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "databricks":
return databricks_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,

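For the cost-calculator change above: `maxsplit=1` keeps the full remainder of a model name that itself contains slashes when the provider prefix is stripped. A quick illustration in plain Python (the model string below is illustrative only):

    model = "vertex_ai/meta/llama-3-405b-instruct-maas"
    print(model.split("/")[1])     # "meta" -> wrong lookup key (old behavior)
    print(model.split("/", 1)[1])  # "meta/llama-3-405b-instruct-maas" (new behavior)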
View file

@@ -245,7 +245,10 @@ class AzureOpenAIConfig:
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
-if json_schema is not None:
+if json_schema is not None and (
+(api_version_year <= "2024" and api_version_month < "08")
+or "gpt-4o" not in model
+): # azure api version "2024-08-01-preview" onwards supports 'json_schema' only for gpt-4o
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(

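With the condition above, a pydantic `response_format` is only converted to a forced tool call when the deployment is not gpt-4o or the API version predates 2024-08-01-preview; otherwise it is passed through as a native json_schema. A minimal sketch of the newly supported path, assuming an `azure/gpt-4o` deployment and Azure credentials in the environment (it mirrors the test added further down):

    import litellm
    from pydantic import BaseModel

    class CalendarEvent(BaseModel):
        name: str
        date: str

    resp = litellm.completion(
        model="azure/gpt-4o",                     # assumed deployment name
        messages=[{"role": "user", "content": "Team sync on Friday"}],
        response_format=CalendarEvent,            # sent as native json_schema
        api_version="2024-08-01-preview",
    )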
View file

@@ -736,7 +736,9 @@ class BedrockLLM(BaseAWSLLM):
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
proxy_endpoint_url = (
f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
)
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
@@ -1268,7 +1270,7 @@ class AmazonConverseConfig:
if len(value) == 0: # converse raises error for empty strings
continue
value = [value]
optional_params["stop_sequences"] = value
optional_params["stopSequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":

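The Converse API expects the camelCase `stopSequences` key, so the old snake_case mapping above did not take effect. A short sketch of the user-facing call this fixes, assuming AWS credentials and region are configured in the environment (model and stop string follow the test added later in this commit):

    import litellm

    resp = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
        stop=["stop sequence"],   # now forwarded to Converse as stopSequences
        max_tokens=100,
    )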
View file

@@ -29,8 +29,8 @@ from litellm.types.utils import (
)
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
-from .base import BaseLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base import BaseLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory
class DatabricksError(Exception):
@@ -328,6 +328,7 @@ class DatabricksChatCompletion(BaseLLM):
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
+custom_llm_provider: str,
print_verbose: Callable,
encoding,
api_key,
@@ -371,6 +372,8 @@
)
response = ModelResponse(**response_json)
response.model = custom_llm_provider + "/" + response.model
+if base_model is not None:
+response._hidden_params["model"] = base_model
return response
@@ -472,6 +475,7 @@
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
+custom_llm_provider=custom_llm_provider,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
@@ -528,6 +532,8 @@
response = ModelResponse(**response_json)
response.model = custom_llm_provider + "/" + response.model
+if base_model is not None:
+response._hidden_params["model"] = base_model

View file

@@ -0,0 +1,39 @@
+"""
+Helper util for handling databricks-specific cost calculation
+- e.g.: handling 'dbrx-instruct-*'
+"""
+from typing import Tuple
+from litellm.types.utils import Usage
+from litellm.utils import get_model_info
+def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
+    """
+    Calculates the cost per token for a given model, prompt tokens, and completion tokens.
+    Input:
+        - model: str, the model name without provider prefix
+        - usage: LiteLLM Usage block, containing anthropic caching information
+    Returns:
+        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+    """
+    base_model = model
+    if model.startswith("databricks/dbrx-instruct") or model.startswith(
+        "dbrx-instruct"
+    ):
+        base_model = "databricks-dbrx-instruct"
+    ## GET MODEL INFO
+    model_info = get_model_info(model=base_model, custom_llm_provider="databricks")
+    ## CALCULATE INPUT COST
+    prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]
+    ## CALCULATE OUTPUT COST
+    completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]
+    return prompt_cost, completion_cost

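A rough usage sketch for the helper above; the token counts are made up and the price lookup assumes a `databricks-dbrx-instruct` entry exists in litellm's model cost map:

    from litellm.llms.databricks.cost_calculator import cost_per_token
    from litellm.types.utils import Usage

    usage = Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
    prompt_usd, completion_usd = cost_per_token(
        model="databricks/dbrx-instruct",   # normalized to "databricks-dbrx-instruct"
        usage=usage,
    )
    print(prompt_usd, completion_usd)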
View file

@@ -273,7 +273,7 @@ class SagemakerLLM(BaseAWSLLM):
model_id = optional_params.get("model_id", None)
if use_messages_api is True:
-from litellm.llms.databricks import DatabricksChatCompletion
+from litellm.llms.databricks.chat import DatabricksChatCompletion
openai_like_chat_completions = DatabricksChatCompletion()
inference_params["stream"] = True if stream is True else False

View file

@@ -80,7 +80,7 @@ class VertexAIPartnerModels(BaseLLM):
import vertexai
from google.cloud import aiplatform
-from litellm.llms.databricks import DatabricksChatCompletion
+from litellm.llms.databricks.chat import DatabricksChatCompletion
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.llms.text_completion_codestral import CodestralTextCompletion
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (

View file

@@ -92,7 +92,7 @@ from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion # type: ignore
from .llms.cohere import embed as cohere_embed
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
-from .llms.databricks import DatabricksChatCompletion
+from .llms.databricks.chat import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
@@ -1013,7 +1013,10 @@ def completion(
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_version = (
-api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
+api_version
+or litellm.api_version
+or get_secret("AZURE_API_VERSION")
+or litellm.AZURE_DEFAULT_API_VERSION
)
api_key = (

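With the fallback above, an Azure completion call no longer errors when no API version is provided anywhere. A minimal sketch, assuming AZURE_API_KEY and AZURE_API_BASE are set in the environment and an `azure/gpt-4o` deployment exists:

    import litellm

    # No api_version argument and no AZURE_API_VERSION env var set:
    # the call now falls back to litellm.AZURE_DEFAULT_API_VERSION ("2024-08-01-preview").
    resp = litellm.completion(
        model="azure/gpt-4o",
        messages=[{"role": "user", "content": "Hello!"}],
    )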
View file

@@ -2512,16 +2512,16 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0.00000035,
"input_cost_per_token_above_128k_tokens": 0.0000007,
"output_cost_per_token": 0.00000105,
"output_cost_per_token_above_128k_tokens": 0.0000021,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
"output_cost_per_token_above_128k_tokens": 0.0000006,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-latest": {
"max_tokens": 8192,
@@ -2533,16 +2533,16 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0.00000035,
"input_cost_per_token_above_128k_tokens": 0.0000007,
"output_cost_per_token": 0.00000105,
"output_cost_per_token_above_128k_tokens": 0.0000021,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
"output_cost_per_token_above_128k_tokens": 0.0000006,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-pro": {
"max_tokens": 8192,

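For scale, the corrected rates above work out to roughly $0.075 per 1M input tokens (0.000000075 * 1,000,000) and $0.30 per 1M output tokens (0.0000003 * 1,000,000) for prompts up to 128k tokens, with the above-128k rates at exactly double that.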
View file

@@ -242,6 +242,9 @@ class LiteLLMRoutes(enum.Enum):
"/v1/models",
# token counter
"/utils/token_counter",
+# rerank
+"/rerank",
+"/v1/rerank",
]
mapped_pass_through_routes: List = [

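With /rerank and /v1/rerank added to this route list, non-admin virtual keys can call the proxy's rerank endpoint (per the commit message). A hedged sketch of such a call; the proxy URL, key, and model alias below are placeholders:

    import requests

    resp = requests.post(
        "http://localhost:4000/rerank",
        headers={"Authorization": "Bearer sk-my-virtual-key"},  # any valid key, not just admin
        json={
            "model": "cohere/rerank-english-v3.0",
            "query": "What is the capital of France?",
            "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
        },
    )
    print(resp.json())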
View file

@@ -891,18 +891,29 @@ def encode_image(image_path):
return base64.b64encode(image_file.read()).decode("utf-8")
-@pytest.mark.skip(
-reason="we already test claude-3, this is just another way to pass images"
-)
-def test_completion_claude_3_base64():
+@pytest.mark.parametrize(
+"model",
+[
+"gpt-4o",
+"azure/gpt-4o",
+"anthropic/claude-3-opus-20240229",
+],
+) #
+def test_completion_base64(model):
try:
import base64
import requests
litellm.set_verbose = True
litellm.num_retries = 3
image_path = "../proxy/cached_logo.jpg"
# Getting the base64 string
base64_image = encode_image(image_path)
url = "https://dummyimage.com/100/100/fff&text=Test+image"
response = requests.get(url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_image = f"data:image/png;base64,{encoded_file}"
resp = litellm.completion(
model="anthropic/claude-3-opus-20240229",
model=model,
messages=[
{
"role": "user",
@@ -910,9 +921,7 @@ def test_completion_claude_3_base64():
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64," + base64_image
},
"image_url": {"url": base64_image},
},
],
}
@@ -921,7 +930,6 @@ def test_completion_claude_3_base64():
print(f"\nResponse: {resp}")
prompt_tokens = resp.usage.prompt_tokens
raise Exception("it worked!")
except Exception as e:
if "500 Internal error encountered.'" in str(e):
pass
@@ -2176,15 +2184,16 @@ def test_completion_openai():
@pytest.mark.parametrize(
"model",
"model, api_version",
[
"gpt-4o-2024-08-06",
"azure/chatgpt-v-2",
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
("gpt-4o-2024-08-06", None),
("azure/chatgpt-v-2", None),
("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", None),
("azure/gpt-4o", "2024-08-01-preview"),
],
)
@pytest.mark.flaky(retries=3, delay=1)
-def test_completion_openai_pydantic(model):
+def test_completion_openai_pydantic(model, api_version):
try:
litellm.set_verbose = True
from pydantic import BaseModel
@@ -2209,6 +2218,7 @@ def test_completion_openai_pydantic(model):
messages=messages,
metadata={"hi": "bye"},
response_format=EventsList,
+api_version=api_version,
)
break
except litellm.JSONSchemaValidationError:
@@ -3471,14 +3481,14 @@ def response_format_tests(response: litellm.ModelResponse):
@pytest.mark.parametrize(
"model",
[
# "bedrock/cohere.command-r-plus-v1:0",
"bedrock/mistral.mistral-large-2407-v1:0",
"bedrock/cohere.command-r-plus-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
# "anthropic.claude-instant-v1",
# "bedrock/ai21.j2-mid",
# "mistral.mistral-7b-instruct-v0:2",
"anthropic.claude-instant-v1",
"mistral.mistral-7b-instruct-v0:2",
# "bedrock/amazon.titan-tg1-large",
# "meta.llama3-8b-instruct-v1:0",
# "cohere.command-text-v14",
"meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14",
],
)
@pytest.mark.parametrize("sync_mode", [True, False])
@@ -3493,6 +3503,7 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
messages=[{"role": "user", "content": "Hey! how's it going?"}],
temperature=0.2,
max_tokens=200,
stop=["stop sequence"],
)
assert isinstance(response, litellm.ModelResponse)
@@ -3504,6 +3515,7 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
messages=[{"role": "user", "content": "Hey! how's it going?"}],
temperature=0.2,
max_tokens=100,
stop=["stop sequence"],
)
assert isinstance(response, litellm.ModelResponse)

View file

@@ -1219,3 +1219,13 @@ def test_completion_cost_anthropic_prompt_caching():
cost_2 = completion_cost(model=model, completion_response=response_2)
assert cost_1 > cost_2
+def test_completion_cost_databricks():
+    model, messages = "databricks/databricks-dbrx-instruct", [
+        {"role": "user", "content": "What is 2+2?"}
+    ]
+    resp = litellm.completion(model=model, messages=messages)  # works fine
+    cost = completion_cost(completion_response=resp)

View file

@@ -2512,16 +2512,16 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0.00000035,
"input_cost_per_token_above_128k_tokens": 0.0000007,
"output_cost_per_token": 0.00000105,
"output_cost_per_token_above_128k_tokens": 0.0000021,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
"output_cost_per_token_above_128k_tokens": 0.0000006,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-latest": {
"max_tokens": 8192,
@@ -2533,16 +2533,16 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0.00000035,
"input_cost_per_token_above_128k_tokens": 0.0000007,
"output_cost_per_token": 0.00000105,
"output_cost_per_token_above_128k_tokens": 0.0000021,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
"output_cost_per_token_above_128k_tokens": 0.0000006,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-pro": {
"max_tokens": 8192,