LiteLLM Minor Fixes & Improvements (10/18/2024) (#6320)

* fix(converse_transformation.py): handle cross region model name when getting openai param support

Fixes https://github.com/BerriAI/litellm/issues/6291

* LiteLLM Minor Fixes & Improvements (10/17/2024)  (#6293)

* fix(ui_sso.py): fix faulty admin only check

Fixes https://github.com/BerriAI/litellm/issues/6286

* refactor(sso_helper_utils.py): refactor /sso/callback to use helper utils, covered by unit testing

Prevent future regressions

* feat(prompt_factory): support 'ensure_alternating_roles' param

Closes https://github.com/BerriAI/litellm/issues/6257
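For context on what the new param guards against: several providers reject conversations where the same role appears in consecutive messages. A minimal sketch of one such strategy, merging consecutive same-role turns (illustrative only; `merge_consecutive_roles` is a hypothetical helper, and litellm's actual prompt_factory implementation may instead insert placeholder turns):

```python
from typing import Dict, List

def merge_consecutive_roles(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Illustrative helper: collapse consecutive same-role messages so roles alternate."""
    merged: List[Dict[str, str]] = []
    for msg in messages:
        if merged and merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] += "\n" + msg["content"]  # fold into the previous turn
        else:
            merged.append(dict(msg))
    return merged

print(merge_consecutive_roles([
    {"role": "user", "content": "Hi"},
    {"role": "user", "content": "Are you there?"},
    {"role": "assistant", "content": "Yes."},
]))
# -> [{'role': 'user', 'content': 'Hi\nAre you there?'}, {'role': 'assistant', 'content': 'Yes.'}]
```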

* fix(proxy/utils.py): add dailytagspend to expected views

* feat(auth_utils.py): support setting regex for clientside auth credentials

Fixes https://github.com/BerriAI/litellm/issues/6203
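The gist of the regex support: instead of comparing a client-supplied credential (e.g. `api_base`) against an exact allow-list, the proxy can now match it against a pattern. A minimal sketch of that check (illustrative only; the pattern list and function below are assumptions, not litellm's config schema):

```python
import re
from typing import List, Optional

# Hypothetical allow-list of regexes for client-supplied api_base values.
ALLOWED_API_BASE_PATTERNS: List[str] = [r"^https://[a-z0-9-]+\.openai\.azure\.com/?$"]

def is_allowed_api_base(client_api_base: Optional[str]) -> bool:
    """Return True if the client-supplied api_base matches an allowed regex."""
    if client_api_base is None:
        return True  # the client did not override anything
    return any(re.match(p, client_api_base) for p in ALLOWED_API_BASE_PATTERNS)

print(is_allowed_api_base("https://my-deployment.openai.azure.com"))  # True
print(is_allowed_api_base("https://attacker.example.com"))            # False
```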

* build(cookbook): add tutorial for mlflow + langchain + litellm proxy tracing

* feat(argilla.py): add argilla logging integration

Closes https://github.com/BerriAI/litellm/issues/6201
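A hedged setup sketch, assuming the integration follows litellm's usual logging-callback pattern; the env-var names and callback string below are assumptions, so defer to the new docs/argilla.md for the authoritative configuration:

```python
import os
import litellm

# Assumed configuration keys for the Argilla logger (verify against docs/argilla.md).
os.environ["ARGILLA_API_KEY"] = "argilla.apikey"
os.environ["ARGILLA_BASE_URL"] = "http://localhost:6900"
os.environ["ARGILLA_DATASET_NAME"] = "litellm-logs"

litellm.callbacks = ["argilla"]  # register the new logging integration

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello!"}],
)
```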

* fix: fix linting errors

* fix: fix ruff error

* test: fix test

* fix: update vertex ai assumption - parts not always guaranteed (#6296)

* docs(configs.md): add argilla env var to docs

* docs(user_keys.md): add regex doc for clientside auth params

* docs(argilla.md): add doc on argilla logging

* docs(argilla.md): add sampling rate to argilla calls

* bump: version 1.49.6 → 1.49.7

* add gpt-4o-audio models to model cost map (#6306)

* (code quality) add ruff check PLR0915 for `too-many-statements` (#6309)

* ruff: enable the PLR0915 rule

* add `# noqa: PLR0915` exemptions to the existing oversized functions flagged by the new rule

* docs: fix "Turn on / off caching per Key" doc (#6297)

* (feat) Support `audio`,  `modalities` params (#6304)

* add audio, modalities param

* add test for gpt audio models

* add get_supported_openai_params for GPT audio models

* add supported params for audio

* test_audio_output_from_model

* bump openai to openai==1.52.0

* bump openai on pyproject

* fix audio test

* fix test mock_chat_response

* handle audio for Message

* fix handling audio for OAI compatible API endpoints

* fix linting

* fix mock dbrx test
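A usage sketch for the new params: it assumes `OPENAI_API_KEY` is set, the `modalities`/`audio` shapes follow the OpenAI chat-completions audio API that these commits pass through, and the `message.audio` field is assumed to mirror OpenAI's response object:

```python
import litellm

response = litellm.completion(
    model="gpt-4o-audio-preview",
    modalities=["text", "audio"],
    audio={"voice": "alloy", "format": "wav"},
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)

audio_out = response.choices[0].message.audio  # new `audio` field on Message
print(audio_out.transcript if audio_out else response.choices[0].message.content)
```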

* (feat) Support audio param in responses streaming (#6312)

* add audio to Delta

* handle model_response.choices.delta.audio

* fix linting

* build(model_prices_and_context_window.json): add gpt-4o-audio audio token cost tracking

* refactor(model_prices_and_context_window.json): refactor 'supports_audio' to be 'supports_audio_input' and 'supports_audio_output'

Allows the flag to be used for OpenAI + Gemini models (both support audio input)

* feat(cost_calculation.py): support cost calc for audio model

Closes https://github.com/BerriAI/litellm/issues/6302

* feat(utils.py): expose new `supports_audio_input` and `supports_audio_output` functions

Closes https://github.com/BerriAI/litellm/issues/6303
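Usage of the new helpers, mirroring the unit test added in this PR:

```python
from litellm.utils import supports_audio_input, supports_audio_output

print(supports_audio_input(model="gpt-4o-audio-preview"))   # True
print(supports_audio_output(model="gpt-4o-audio-preview"))  # True
print(supports_audio_input(model="gpt-3.5-turbo"))          # False
```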

* feat(handle_jwt.py): support JWK responses where `keys` is a single dict instead of a list

* fix(cost_calculator.py): fix linting errors

* fix: fix linting error

* fix(cost_calculator): move to using standard openai usage cached tokens value

* test: fix test

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Krish Dholakia, 2024-10-19 22:23:27 -07:00, committed by GitHub
commit 7cc12bd5c6 (parent c58d542282)

19 changed files with 496 additions and 121 deletions

View file

@ -233,7 +233,7 @@ class Cache:
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
def get_cache_key(self, *args, **kwargs) -> str:
def get_cache_key(self, *args, **kwargs) -> str: # noqa: PLR0915
"""
Get the cache key for the given arguments.

View file

@ -37,12 +37,16 @@ from litellm.llms.databricks.cost_calculator import (
from litellm.llms.fireworks_ai.cost_calculator import (
cost_per_token as fireworks_ai_cost_per_token,
)
from litellm.llms.OpenAI.cost_calculation import (
cost_per_second as openai_cost_per_second,
)
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.OpenAI.cost_calculation import cost_router as openai_cost_router
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import PassthroughCallTypes, Usage
from litellm.types.utils import CallTypesLiteral, PassthroughCallTypes, Usage
from litellm.utils import (
CallTypes,
CostPerToken,
@ -97,25 +101,10 @@ def cost_per_token( # noqa: PLR0915
custom_cost_per_second: Optional[float] = None,
### NUMBER OF QUERIES ###
number_of_queries: Optional[int] = None,
### USAGE OBJECT ###
usage_object: Optional[Usage] = None, # just read the usage object if provided
### CALL TYPE ###
call_type: Literal[
"embedding",
"aembedding",
"completion",
"acompletion",
"atext_completion",
"text_completion",
"image_generation",
"aimage_generation",
"moderation",
"amoderation",
"atranscription",
"transcription",
"aspeech",
"speech",
"rerank",
"arerank",
] = "completion",
call_type: CallTypesLiteral = "completion",
) -> Tuple[float, float]: # type: ignore
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@ -139,13 +128,16 @@ def cost_per_token( # noqa: PLR0915
raise Exception("Invalid arg. Model cannot be none.")
## RECONSTRUCT USAGE BLOCK ##
usage_block = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
)
if usage_object is not None:
usage_block = usage_object
else:
usage_block = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
)
## CUSTOM PRICING ##
response_cost = _cost_per_token_custom_pricing_helper(
@ -264,9 +256,13 @@ def cost_per_token( # noqa: PLR0915
elif custom_llm_provider == "anthropic":
return anthropic_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "openai":
return openai_cost_per_token(
model=model, usage=usage_block, response_time_ms=response_time_ms
)
openai_cost_route = openai_cost_router(call_type=CallTypes(call_type))
if openai_cost_route == "cost_per_token":
return openai_cost_per_token(model=model, usage=usage_block)
elif openai_cost_route == "cost_per_second":
return openai_cost_per_second(
model=model, usage=usage_block, response_time_ms=response_time_ms
)
elif custom_llm_provider == "databricks":
return databricks_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "fireworks_ai":
@ -474,6 +470,45 @@ def _select_model_name_for_cost_calc(
return return_model
def _get_usage_object(
completion_response: Any,
) -> Optional[Usage]:
usage_obj: Optional[Usage] = None
if completion_response is not None and isinstance(
completion_response, ModelResponse
):
usage_obj = completion_response.get("usage")
return usage_obj
def _infer_call_type(
call_type: Optional[CallTypesLiteral], completion_response: Any
) -> Optional[CallTypesLiteral]:
if call_type is not None:
return call_type
if completion_response is None:
return None
if isinstance(completion_response, ModelResponse):
return "completion"
elif isinstance(completion_response, EmbeddingResponse):
return "embedding"
elif isinstance(completion_response, TranscriptionResponse):
return "transcription"
elif isinstance(completion_response, HttpxBinaryResponseContent):
return "speech"
elif isinstance(completion_response, RerankResponse):
return "rerank"
elif isinstance(completion_response, ImageResponse):
return "image_generation"
elif isinstance(completion_response, TextCompletionResponse):
return "text_completion"
return call_type
def completion_cost( # noqa: PLR0915
completion_response=None,
model: Optional[str] = None,
@ -481,24 +516,7 @@ def completion_cost( # noqa: PLR0915
messages: List = [],
completion="",
total_time: Optional[float] = 0.0, # used for replicate, sagemaker
call_type: Literal[
"embedding",
"aembedding",
"completion",
"acompletion",
"atext_completion",
"text_completion",
"image_generation",
"aimage_generation",
"moderation",
"amoderation",
"atranscription",
"transcription",
"aspeech",
"speech",
"rerank",
"arerank",
] = "completion",
call_type: Optional[CallTypesLiteral] = None,
### REGION ###
custom_llm_provider=None,
region_name=None, # used for bedrock pricing
@ -539,6 +557,7 @@ def completion_cost( # noqa: PLR0915
- For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
"""
try:
call_type = _infer_call_type(call_type, completion_response) or "completion"
if (
(call_type == "aimage_generation" or call_type == "image_generation")
and model is not None
@ -554,6 +573,9 @@ def completion_cost( # noqa: PLR0915
completion_characters: Optional[int] = None
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
cost_per_token_usage_object: Optional[litellm.Usage] = _get_usage_object(
completion_response=completion_response
)
if completion_response is not None and (
isinstance(completion_response, BaseModel)
or isinstance(completion_response, dict)
@ -760,6 +782,7 @@ def completion_cost( # noqa: PLR0915
completion_characters=completion_characters,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
usage_object=cost_per_token_usage_object,
call_type=call_type,
)
_final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
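Net effect of this hunk, as a sketch: `completion_cost()` can now read the `Usage` object straight off the response (including cached/audio token details) and infer the call type from the response class, so neither has to be passed explicitly. Assumes a valid `OPENAI_API_KEY`:

```python
import litellm

response = litellm.completion(
    model="gpt-4o-audio-preview",
    modalities=["text", "audio"],
    audio={"voice": "alloy", "format": "wav"},
    messages=[{"role": "user", "content": "hi"}],
)

# call_type is inferred from the ModelResponse; usage comes from response.usage
cost = litellm.completion_cost(completion_response=response)
print(f"request cost: ${cost:.6f}")
```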

View file

@ -25,10 +25,13 @@ def cost_per_token(
"""
## GET MODEL INFO
model_info = get_model_info(model=model, custom_llm_provider="azure")
cached_tokens: Optional[int] = None
## CALCULATE INPUT COST
total_prompt_tokens: float = usage["prompt_tokens"] - usage._cache_read_input_tokens
prompt_cost: float = total_prompt_tokens * model_info["input_cost_per_token"]
non_cached_text_tokens = usage.prompt_tokens
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cached_tokens = usage.prompt_tokens_details.cached_tokens
non_cached_text_tokens = non_cached_text_tokens - cached_tokens
prompt_cost: float = non_cached_text_tokens * model_info["input_cost_per_token"]
## CALCULATE OUTPUT COST
completion_cost: float = (
@ -36,9 +39,9 @@ def cost_per_token(
)
## Prompt Caching cost calculation
if model_info.get("cache_read_input_token_cost") is not None:
if model_info.get("cache_read_input_token_cost") is not None and cached_tokens:
# Note: cached tokens are read from usage.prompt_tokens_details.cached_tokens, since cost_calculator.py standardizes cache-read tokens on the OpenAI usage field
prompt_cost += usage._cache_read_input_tokens * (
prompt_cost += cached_tokens * (
model_info.get("cache_read_input_token_cost", 0) or 0
)

View file

@ -3,16 +3,21 @@ Helper util for handling openai-specific cost calculation
- e.g.: prompt caching
"""
from typing import Optional, Tuple
from typing import Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.types.utils import Usage
from litellm.types.utils import CallTypes, Usage
from litellm.utils import get_model_info
def cost_per_token(
model: str, usage: Usage, response_time_ms: Optional[float] = 0.0
) -> Tuple[float, float]:
def cost_router(call_type: CallTypes) -> Literal["cost_per_token", "cost_per_second"]:
if call_type == CallTypes.atranscription or call_type == CallTypes.transcription:
return "cost_per_second"
else:
return "cost_per_token"
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@ -27,21 +32,61 @@ def cost_per_token(
model_info = get_model_info(model=model, custom_llm_provider="openai")
## CALCULATE INPUT COST
total_prompt_tokens: float = usage["prompt_tokens"] - usage._cache_read_input_tokens
prompt_cost: float = total_prompt_tokens * model_info["input_cost_per_token"]
### Non-cached text tokens
non_cached_text_tokens = usage.prompt_tokens
cached_tokens: Optional[int] = None
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cached_tokens = usage.prompt_tokens_details.cached_tokens
non_cached_text_tokens = non_cached_text_tokens - cached_tokens
prompt_cost: float = non_cached_text_tokens * model_info["input_cost_per_token"]
## Prompt Caching cost calculation
if model_info.get("cache_read_input_token_cost") is not None and cached_tokens:
# Note: cached tokens are read from usage.prompt_tokens_details.cached_tokens, since cost_calculator.py standardizes cache-read tokens on the OpenAI usage field
prompt_cost += cached_tokens * (
model_info.get("cache_read_input_token_cost", 0) or 0
)
_audio_tokens: Optional[int] = (
usage.prompt_tokens_details.audio_tokens
if usage.prompt_tokens_details is not None
else None
)
_audio_cost_per_token: Optional[float] = model_info.get(
"input_cost_per_audio_token"
)
if _audio_tokens is not None and _audio_cost_per_token is not None:
audio_cost: float = _audio_tokens * _audio_cost_per_token
prompt_cost += audio_cost
## CALCULATE OUTPUT COST
completion_cost: float = (
usage["completion_tokens"] * model_info["output_cost_per_token"]
)
_output_cost_per_audio_token: Optional[float] = model_info.get(
"output_cost_per_audio_token"
)
_output_audio_tokens: Optional[int] = (
usage.completion_tokens_details.audio_tokens
if usage.completion_tokens_details is not None
else None
)
if _output_cost_per_audio_token is not None and _output_audio_tokens is not None:
audio_cost = _output_audio_tokens * _output_cost_per_audio_token
completion_cost += audio_cost
## Prompt Caching cost calculation
if model_info.get("cache_read_input_token_cost") is not None:
# Note: We read ._cache_read_input_tokens from the Usage - since cost_calculator.py standardizes the cache read tokens on usage._cache_read_input_tokens
prompt_cost += usage._cache_read_input_tokens * (
model_info.get("cache_read_input_token_cost", 0) or 0
)
return prompt_cost, completion_cost
def cost_per_second(
model: str, usage: Usage, response_time_ms: Optional[float] = 0.0
) -> Tuple[float, float]:
"""
Calculates the cost per second for a given model, prompt tokens, and completion tokens.
"""
## GET MODEL INFO
model_info = get_model_info(model=model, custom_llm_provider="openai")
prompt_cost = 0.0
completion_cost = 0.0
## Speech / Audio cost calculation
if (
"output_cost_per_second" in model_info
@ -52,7 +97,6 @@ def cost_per_token(
f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; response time: {response_time_ms}"
)
## COST PER SECOND ##
prompt_cost = 0
completion_cost = model_info["output_cost_per_second"] * response_time_ms / 1000
elif (
"input_cost_per_second" in model_info

View file

@ -43,7 +43,7 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk, PromptTokensDetails
from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...base import BaseLLM
@ -294,7 +294,7 @@ class AnthropicChatCompletion(BaseLLM):
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetails(
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
total_tokens = prompt_tokens + completion_tokens

View file

@ -82,15 +82,19 @@ class AmazonConverseConfig:
"response_format",
]
## Filter out 'cross-region' from model name
base_model = self._get_base_model(model)
if (
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
or model.startswith("meta.llama3-1")
base_model.startswith("anthropic")
or base_model.startswith("mistral")
or base_model.startswith("cohere")
or base_model.startswith("meta.llama3-1")
or base_model.startswith("meta.llama3-2")
):
supported_params.append("tools")
if model.startswith("anthropic") or model.startswith("mistral"):
if base_model.startswith("anthropic") or base_model.startswith("mistral"):
# only anthropic and mistral support tool_choice config; others (e.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
supported_params.append("tool_choice")

View file

@ -5403,7 +5403,7 @@ def stream_chunk_builder_text_completion(
return TextCompletionResponse(**response)
def stream_chunk_builder(
def stream_chunk_builder( # noqa: PLR0915
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
try:

View file

@ -10,7 +10,8 @@
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_audio": true,
"supports_audio_input": true,
"supports_audio_output": true,
"supports_prompt_caching": true
},
"gpt-4": {
@ -43,24 +44,30 @@
"max_input_tokens": 128000,
"max_output_tokens": 16384,
"input_cost_per_token": 0.0000025,
"input_cost_per_audio_token": 0.0001,
"output_cost_per_token": 0.000010,
"output_cost_per_audio_token": 0.0002,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_audio": true
"supports_audio_input": true,
"supports_audio_output": true
},
"gpt-4o-audio-preview-2024-10-01": {
"max_tokens": 16384,
"max_input_tokens": 128000,
"max_output_tokens": 16384,
"input_cost_per_token": 0.0000025,
"input_cost_per_audio_token": 0.0001,
"output_cost_per_token": 0.000010,
"output_cost_per_audio_token": 0.0002,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_audio": true
"supports_audio_input": true,
"supports_audio_output": true
},
"gpt-4o-mini": {
"max_tokens": 16384,

View file

@ -3,4 +3,3 @@ model_list:
litellm_params:
model: gpt-4o-audio-preview
api_key: os.environ/OPENAI_API_KEY

View file

@ -2035,3 +2035,14 @@ class SpecialHeaders(enum.Enum):
class LitellmDataForBackendLLMCall(TypedDict, total=False):
headers: dict
organization: str
class JWTKeyItem(TypedDict, total=False):
kid: str
JWKKeyValue = Union[List[JWTKeyItem], JWTKeyItem]
class JWKUrlResponse(TypedDict, total=False):
keys: JWKKeyValue

View file

@ -8,7 +8,7 @@ JWT token must have 'litellm_proxy_admin' in scope.
import json
import os
from typing import Optional
from typing import Optional, cast
from cryptography import x509
from cryptography.hazmat.backends import default_backend
@ -17,7 +17,7 @@ from cryptography.hazmat.primitives import serialization
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
from litellm.proxy._types import JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth
from litellm.proxy.utils import PrismaClient
@ -174,7 +174,7 @@ class JWTHandler:
response_json = response.json()
if "keys" in response_json:
keys = response.json()["keys"]
keys: JWKKeyValue = response.json()["keys"]
else:
keys = response_json
@ -186,27 +186,35 @@ class JWTHandler:
else:
keys = cached_keys
public_key: Optional[dict] = None
if len(keys) == 1:
if kid is None or keys["kid"] == kid:
public_key = keys[0]
elif len(keys) > 1:
for key in keys:
if kid is not None and key == kid:
public_key = keys[key]
elif (
kid is not None
and isinstance(key, dict)
and key.get("kid", None) is not None
and key["kid"] == kid
):
public_key = key
public_key = self.parse_keys(keys=keys, kid=kid)
if public_key is None:
raise Exception(
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}"
)
return cast(dict, public_key)
def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
public_key: Optional[JWTKeyItem] = None
if len(keys) == 1:
if isinstance(keys, dict) and (keys.get("kid", None) == kid or kid is None):
public_key = keys
elif isinstance(keys, list) and (
keys[0].get("kid", None) == kid or kid is None
):
public_key = keys[0]
elif len(keys) > 1:
for key in keys:
if isinstance(key, dict):
key_kid = key.get("kid", None)
else:
key_kid = None
if (
kid is not None
and isinstance(key, dict)
and key_kid is not None
and key_kid == kid
):
public_key = key
return public_key
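A minimal sketch of the refactored helper, covering both shapes of `JWKKeyValue` (the single dict is trimmed to one field purely for illustration):

```python
from litellm.proxy.auth.handle_jwt import JWTHandler

jwt_handler = JWTHandler()

# Usual case: a list of JWKs, matched by kid
key = jwt_handler.parse_keys(keys=[{"kid": "key-1"}, {"kid": "key-2"}], kid="key-2")
print(key)  # {'kid': 'key-2'}

# Newly supported case: the JWK endpoint returned a single dict
key = jwt_handler.parse_keys(keys={"kid": "key-1"}, kid=None)
print(key)  # {'kid': 'key-1'}
```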

View file

@ -1543,6 +1543,7 @@ class ProxyConfig:
## INIT PROXY REDIS USAGE CLIENT ##
redis_usage_cache = litellm.cache.cache
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
"""
Load config file

View file

@ -66,6 +66,7 @@ class ModelInfo(TypedDict, total=False):
cache_creation_input_token_cost: Optional[float]
cache_read_input_token_cost: Optional[float]
input_cost_per_character: Optional[float] # only for vertex ai models
input_cost_per_audio_token: Optional[float]
input_cost_per_token_above_128k_tokens: Optional[float] # only for vertex ai models
input_cost_per_character_above_128k_tokens: Optional[
float
@ -77,6 +78,7 @@ class ModelInfo(TypedDict, total=False):
input_cost_per_second: Optional[float] # for OpenAI Speech models
output_cost_per_token: Required[float]
output_cost_per_character: Optional[float] # only for vertex ai models
output_cost_per_audio_token: Optional[float]
output_cost_per_token_above_128k_tokens: Optional[
float
] # only for vertex ai models
@ -102,6 +104,8 @@ class ModelInfo(TypedDict, total=False):
supports_function_calling: Optional[bool]
supports_assistant_prefill: Optional[bool]
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
class GenericStreamingChunk(TypedDict, total=False):
@ -139,6 +143,27 @@ class CallTypes(Enum):
arealtime = "_arealtime"
CallTypesLiteral = Literal[
"embedding",
"aembedding",
"completion",
"acompletion",
"atext_completion",
"text_completion",
"image_generation",
"aimage_generation",
"moderation",
"amoderation",
"atranscription",
"transcription",
"aspeech",
"speech",
"rerank",
"arerank",
"_arealtime",
]
class PassthroughCallTypes(Enum):
passthrough_image_generation = "passthrough-image-generation"
@ -535,6 +560,23 @@ class Choices(OpenAIObject):
setattr(self, key, value)
class CompletionTokensDetailsWrapper(
CompletionTokensDetails
): # wrapper for older openai versions
text_tokens: Optional[int] = None
"""Text tokens generated by the model."""
class PromptTokensDetailsWrapper(
PromptTokensDetails
): # wrapper for older openai versions
text_tokens: Optional[int] = None
"""Text tokens sent to the model."""
image_tokens: Optional[int] = None
"""Image tokens sent to the model."""
class Usage(CompletionUsage):
_cache_creation_input_tokens: int = PrivateAttr(
0
@ -549,23 +591,23 @@ class Usage(CompletionUsage):
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None,
reasoning_tokens: Optional[int] = None,
prompt_tokens_details: Optional[Union[PromptTokensDetails, dict]] = None,
prompt_tokens_details: Optional[Union[PromptTokensDetailsWrapper, dict]] = None,
completion_tokens_details: Optional[
Union[CompletionTokensDetails, dict]
Union[CompletionTokensDetailsWrapper, dict]
] = None,
**params,
):
# handle reasoning_tokens
_completion_tokens_details: Optional[CompletionTokensDetails] = None
_completion_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
if reasoning_tokens:
completion_tokens_details = CompletionTokensDetails(
completion_tokens_details = CompletionTokensDetailsWrapper(
reasoning_tokens=reasoning_tokens
)
# Ensure completion_tokens_details is properly handled
if completion_tokens_details:
if isinstance(completion_tokens_details, dict):
_completion_tokens_details = CompletionTokensDetails(
_completion_tokens_details = CompletionTokensDetailsWrapper(
**completion_tokens_details
)
elif isinstance(completion_tokens_details, CompletionTokensDetails):
@ -576,7 +618,7 @@ class Usage(CompletionUsage):
params["prompt_cache_hit_tokens"], int
):
if prompt_tokens_details is None:
prompt_tokens_details = PromptTokensDetails(
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=params["prompt_cache_hit_tokens"]
)
@ -585,15 +627,17 @@ class Usage(CompletionUsage):
params["cache_read_input_tokens"], int
):
if prompt_tokens_details is None:
prompt_tokens_details = PromptTokensDetails(
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=params["cache_read_input_tokens"]
)
# handle prompt_tokens_details
_prompt_tokens_details: Optional[PromptTokensDetails] = None
_prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
if prompt_tokens_details:
if isinstance(prompt_tokens_details, dict):
_prompt_tokens_details = PromptTokensDetails(**prompt_tokens_details)
_prompt_tokens_details = PromptTokensDetailsWrapper(
**prompt_tokens_details
)
elif isinstance(prompt_tokens_details, PromptTokensDetails):
_prompt_tokens_details = prompt_tokens_details
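A construction sketch for the new wrapper types; they extend the OpenAI SDK usage-detail models (per the "wrapper for older openai versions" comments above), so cached and audio token counts can be attached to a `Usage` object directly:

```python
from litellm.types.utils import (
    CompletionTokensDetailsWrapper,
    PromptTokensDetailsWrapper,
    Usage,
)

usage = Usage(
    prompt_tokens=1_000,
    completion_tokens=200,
    total_tokens=1_200,
    prompt_tokens_details=PromptTokensDetailsWrapper(cached_tokens=200, audio_tokens=500),
    completion_tokens_details=CompletionTokensDetailsWrapper(audio_tokens=150),
)
print(usage.prompt_tokens_details.cached_tokens)     # 200
print(usage.completion_tokens_details.audio_tokens)  # 150
```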

View file

@ -1834,6 +1834,54 @@ def supports_function_calling(
)
def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
"""
Check if the given model supports the feature named by `key` and return a boolean value.
Parameters:
model (str): The model name to be checked.
custom_llm_provider (Optional[str]): The provider to be checked.
Returns:
bool: True if the model supports the feature, False otherwise.
Raises:
Exception: If the given model is not found or there's an error in retrieval.
"""
try:
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get(key, False) is True:
return True
return False
except Exception as e:
raise Exception(
f"Model not found or error in checking {key} support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
)
def supports_audio_input(model: str, custom_llm_provider: Optional[str] = None) -> bool:
"""Check if a given model supports audio input in a chat completion call"""
return _supports_factory(
model=model, custom_llm_provider=custom_llm_provider, key="supports_audio_input"
)
def supports_audio_output(
model: str, custom_llm_provider: Optional[str] = None
) -> bool:
"""Check if a given model supports audio output in a chat completion call"""
return _supports_factory(
model=model, custom_llm_provider=custom_llm_provider, key="supports_audio_output"
)
def supports_prompt_caching(
model: str, custom_llm_provider: Optional[str] = None
) -> bool:
@ -4601,9 +4649,11 @@ def get_model_info( # noqa: PLR0915
] # only for vertex ai models
input_cost_per_query: Optional[float] # only for rerank models
input_cost_per_image: Optional[float] # only for vertex ai models
input_cost_per_audio_token: Optional[float]
input_cost_per_audio_per_second: Optional[float] # only for vertex ai models
input_cost_per_video_per_second: Optional[float] # only for vertex ai models
output_cost_per_token: Required[float]
output_cost_per_audio_token: Optional[float]
output_cost_per_character: Optional[float] # only for vertex ai models
output_cost_per_token_above_128k_tokens: Optional[
float
@ -4627,6 +4677,8 @@ def get_model_info( # noqa: PLR0915
supports_vision: Optional[bool]
supports_function_calling: Optional[bool]
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
Raises:
Exception: If the model is not mapped yet.
@ -4870,7 +4922,13 @@ def get_model_info( # noqa: PLR0915
),
input_cost_per_query=_model_info.get("input_cost_per_query", None),
input_cost_per_second=_model_info.get("input_cost_per_second", None),
input_cost_per_audio_token=_model_info.get(
"input_cost_per_audio_token", None
),
output_cost_per_token=_output_cost_per_token,
output_cost_per_audio_token=_model_info.get(
"output_cost_per_audio_token", None
),
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
@ -4903,6 +4961,8 @@ def get_model_info( # noqa: PLR0915
supports_prompt_caching=_model_info.get(
"supports_prompt_caching", False
),
supports_audio_input=_model_info.get("supports_audio_input", False),
supports_audio_output=_model_info.get("supports_audio_output", False),
)
except Exception:
raise Exception(

View file

@ -10,7 +10,8 @@
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_audio": true,
"supports_audio_input": true,
"supports_audio_output": true,
"supports_prompt_caching": true
},
"gpt-4": {
@ -43,24 +44,30 @@
"max_input_tokens": 128000,
"max_output_tokens": 16384,
"input_cost_per_token": 0.0000025,
"input_cost_per_audio_token": 0.0001,
"output_cost_per_token": 0.000010,
"output_cost_per_audio_token": 0.0002,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_audio": true
"supports_audio_input": true,
"supports_audio_output": true
},
"gpt-4o-audio-preview-2024-10-01": {
"max_tokens": 16384,
"max_input_tokens": 128000,
"max_output_tokens": 16384,
"input_cost_per_token": 0.0000025,
"input_cost_per_audio_token": 0.0001,
"output_cost_per_token": 0.000010,
"output_cost_per_audio_token": 0.0002,
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_audio": true
"supports_audio_input": true,
"supports_audio_output": true
},
"gpt-4o-mini": {
"max_tokens": 16384,

View file

@ -84,13 +84,41 @@ def test_bedrock_optional_params_embeddings():
],
)
def test_bedrock_optional_params_completions(model):
litellm.drop_params = True
tools = [
{
"type": "function",
"function": {
"name": "structure_output",
"description": "Send structured output back to the user",
"strict": True,
"parameters": {
"type": "object",
"properties": {
"reasoning": {"type": "string"},
"sentiment": {"type": "string"},
},
"required": ["reasoning", "sentiment"],
"additionalProperties": False,
},
"additionalProperties": False,
},
}
]
optional_params = get_optional_params(
model=model, max_tokens=10, temperature=0.1, custom_llm_provider="bedrock"
model=model,
max_tokens=10,
temperature=0.1,
tools=tools,
custom_llm_provider="bedrock",
)
print(f"optional_params: {optional_params}")
assert len(optional_params) == 3
assert optional_params == {"maxTokens": 10, "stream": False, "temperature": 0.1}
assert len(optional_params) == 4
assert optional_params == {
"maxTokens": 10,
"stream": False,
"temperature": 0.1,
"tools": tools,
}
@pytest.mark.parametrize(

File diff suppressed because one or more lines are too long

View file

@ -993,3 +993,29 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
): # Replace with the actual exception raised on failure
resp = await user_api_key_auth(request=request, api_key=bearer_token)
print(resp)
def test_get_public_key_from_jwk_url():
import litellm
from litellm.proxy.auth.handle_jwt import JWTHandler
jwt_handler = JWTHandler()
jwk_response = [
{
"kty": "RSA",
"alg": "RS256",
"kid": "RaPJB8QVptWHjHcoHkVlUWO4f0D3BtcY6iSDXgGVBgk",
"use": "sig",
"e": "AQAB",
"n": "zgLDu57gLpkzzIkKrTKQVyjK8X40hvu6X_JOeFjmYmI0r3bh7FTOmre5rTEkDOL-1xvQguZAx4hjKmCzBU5Kz84FbsGiqM0ug19df4kwdTS6XOM6YEKUZrbaw4P7xTPsbZj7W2G_kxWNm3Xaxq6UKFdUF7n9snnBKKD6iUA-cE6HfsYmt9OhYZJfy44dbAbuanFmAsWw97SHrPFL3ueh3Ixt19KgpF4iSsXNg3YvoesdFM8psmivgePyyHA8k7pK1Yq7rNQX1Q9nzhvP-F7ocFbP52KYPlaSTu30YwPTVTFKYpDNmHT1fZ7LXZZNLrP_7-NSY76HS2ozSpzjsGVelQ",
}
]
public_key = jwt_handler.parse_keys(
keys=jwk_response,
kid="RaPJB8QVptWHjHcoHkVlUWO4f0D3BtcY6iSDXgGVBgk",
)
assert public_key is not None
assert public_key == jwk_response[0]

View file

@ -833,3 +833,17 @@ def test_is_base64_encoded():
from litellm.utils import is_base64_encoded
assert is_base64_encoded(s=base64_image) is True
@pytest.mark.parametrize(
"model, expected_bool", [("gpt-3.5-turbo", False), ("gpt-4o-audio-preview", True)]
)
def test_supports_audio_input(model, expected_bool):
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
from litellm.utils import supports_audio_input, supports_audio_output
supports_pc = supports_audio_input(model=model)
assert supports_pc == expected_bool