LiteLLM Minor Fixes & Improvements (12/23/2024) - p3 (#7394)

* build(model_prices_and_context_window.json): add gemini-1.5-flash context caching

* fix(context_caching/transformation.py): just use last identified cache point

Fixes https://github.com/BerriAI/litellm/issues/6738

* fix(context_caching/transformation.py): pick first contiguous block - handles system message error from google

Fixes https://github.com/BerriAI/litellm/issues/6738

* fix(vertex_ai/gemini/): track context caching tokens

* refactor(gemini/): place transformation.py inside `chat/` folder

make it easy for the user to know we support the equivalent endpoint

* fix: fix import

* refactor(vertex_ai/): move vertex_ai cost calc inside vertex_ai/ folder

make it easier to see cost calculation logic

* fix: fix linting errors

* fix: fix circular import

* feat(gemini/cost_calculator.py): support gemini context caching cost calculation

generifies anthropic's cost calculation function and uses it across anthropic + gemini

* build(model_prices_and_context_window.json): add cost tracking for gemini-1.5-flash-002 w/ context caching

Closes https://github.com/BerriAI/litellm/issues/6891

* docs(gemini.md): add gemini context caching architecture diagram

make it easier for the user to understand how context caching works

* docs(gemini.md): link to relevant gemini context caching code

* docs(gemini/context_caching): add readme on github, make it easy for devs to know context caching is supported + where to go for the code

* fix(llm_cost_calc/utils.py): handle gemini 128k token diff cost calc scenario

* fix(deepseek/cost_calculator.py): support deepseek context caching cost calculation

* test: fix test
Krish Dholakia 2024-12-23 22:02:52 -08:00 committed by GitHub
parent 442d309bcd
commit c3edfc2c92
20 changed files with 719 additions and 447 deletions


@ -10,7 +10,8 @@ import TabItem from '@theme/TabItem';
| Provider Route on LiteLLM | `gemini/` |
| Provider Doc | [Google AI Studio ↗](https://ai.google.dev/aistudio) |
| API Endpoint for Provider | https://generativelanguage.googleapis.com |
| Supported Endpoints | `/chat/completions`, `/embeddings` |
| Supported OpenAI Endpoints | `/chat/completions`, `/embeddings`, `/completions` |
| Pass-through Endpoint | [Supported](../pass_through/google_ai_studio.md) |
<br />
@ -552,175 +553,6 @@ content = response.get('choices', [{}])[0].get('message', {}).get('content')
print(content)
```
## Context Caching
Google AI Studio context caching is supported by adding
```json
{
    ...,
    "cache_control": {"type": "ephemeral"}
}
```
to your message content block.
:::note
Gemini Context Caching only allows 1 block of continuous messages to be cached.
The raw request to Gemini looks like this:
```bash
curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-001:generateContent?key=$GOOGLE_API_KEY" \
-H 'Content-Type: application/json' \
-d '{
"contents": [
{
"parts":[{
"text": "Please summarize this transcript"
}],
"role": "user"
},
],
"cachedContent": "'$CACHE_NAME'"
}'
```
:::
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
for _ in range(2):
resp = completion(
model="gemini/gemini-1.5-pro",
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
}]
)
print(resp.usage) # 👈 2nd usage block will be less, since cached tokens used
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-1.5-pro
litellm_params:
model: gemini/gemini-1.5-pro
api_key: os.environ/GEMINI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
[**See Langchain, OpenAI JS, Llamaindex, etc. examples**](../proxy/user_keys.md#request-format)
<Tabs>
<TabItem value="curl" label="Curl">
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
"model": "gemini-1.5-pro",
"messages": [
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
}],
}'
```
</TabItem>
<TabItem value="openai-python" label="OpenAI Python SDK">
```python
import openai
client = openai.AsyncOpenAI(
api_key="anything", # litellm proxy api key
base_url="http://0.0.0.0:4000" # litellm proxy base url
)
response = await client.chat.completions.create(
model="gemini-1.5-pro",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
]
)
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Usage - PDF / Videos / etc. Files
### Inline Data (e.g. audio stream)
@ -857,3 +689,191 @@ response = litellm.completion(
| gemini-pro | `completion(model='gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro-latest | `completion(model='gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-pro-vision | `completion(model='gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
## Context Caching
Google AI Studio context caching is supported by marking the messages you want cached with `cache_control` in your message content block:
```python
[
    {
        "role": "system",
        "content": ...,
        "cache_control": {"type": "ephemeral"} # 👈 KEY CHANGE
    },
    ...
]
```
### Architecture Diagram
<Image img={require('../../img/gemini_context_caching.png')} />
**Notes:**
- [Relevant code](https://github.com/BerriAI/litellm/blob/main/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py#L255)
- Gemini Context Caching only allows 1 block of continuous messages to be cached.
- If multiple non-continuous blocks contain `cache_control`, only the first continuous block will be used (it is sent to `/cachedContent` in the [Gemini format](https://ai.google.dev/api/caching#cache_create-SHELL)). See the sketch after the raw request below.
- The raw request to Gemini's `/generateContent` endpoint looks like this:
```bash
curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-001:generateContent?key=$GOOGLE_API_KEY" \
-H 'Content-Type: application/json' \
-d '{
"contents": [
{
"parts":[{
"text": "Please summarize this transcript"
}],
"role": "user"
},
],
"cachedContent": "'$CACHE_NAME'"
}'
```
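The block-selection logic can be sketched with the internal helper that performs the split (a minimal illustration only; `separate_cached_messages` lives in `litellm/llms/vertex_ai/context_caching/transformation.py`, and the message contents below are placeholders):
```python
from litellm.llms.vertex_ai.context_caching.transformation import (
    separate_cached_messages,
)

messages = [
    # contiguous block marked for caching -> sent to /cachedContent
    {"role": "system", "content": [{"type": "text", "text": "long system prompt " * 400, "cache_control": {"type": "ephemeral"}}]},
    {"role": "user", "content": [{"type": "text", "text": "first question", "cache_control": {"type": "ephemeral"}}]},
    # unmarked message ends the contiguous block
    {"role": "assistant", "content": "first answer"},
    # marked, but non-contiguous -> sent as a normal (non-cached) message
    {"role": "user", "content": [{"type": "text", "text": "follow-up question", "cache_control": {"type": "ephemeral"}}]},
]

cached, non_cached = separate_cached_messages(messages)
print(len(cached), len(non_cached))  # 2 2 -> only the first continuous block is cached
```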
### Example Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
for _ in range(2):
resp = completion(
model="gemini/gemini-1.5-pro",
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
}]
)
print(resp.usage) # 👈 2nd usage block will be less, since cached tokens used
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-1.5-pro
litellm_params:
model: gemini/gemini-1.5-pro
api_key: os.environ/GEMINI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
[**See Langchain, OpenAI JS, Llamaindex, etc. examples**](../proxy/user_keys.md#request-format)
<Tabs>
<TabItem value="curl" label="Curl">
```bash
# cache_control marks a content block for Gemini context caching - 👈 KEY CHANGE vs. a regular request
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gemini-1.5-pro",
    "messages": [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a complex legal agreement <paste the agreement text here so the block exceeds the minimum cacheable token count>",
                    "cache_control": {"type": "ephemeral"}
                }
            ]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What are the key terms and conditions in this agreement?",
                    "cache_control": {"type": "ephemeral"}
                }
            ]
        }
    ]
}'
```
</TabItem>
<TabItem value="openai-python" label="OpenAI Python SDK">
```python
import openai
client = openai.AsyncOpenAI(
api_key="anything", # litellm proxy api key
base_url="http://0.0.0.0:4000" # litellm proxy base url
)
response = await client.chat.completions.create(
model="gemini-1.5-pro",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
]
)
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
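To verify that caching kicked in, inspect the usage block on the response. With this change, Gemini cached tokens are reported under `prompt_tokens_details.cached_tokens`. A minimal sketch (assumes `GEMINI_API_KEY` is set and a model with cache pricing such as `gemini/gemini-1.5-flash-002`; the cached count only shows up once the cache exists, i.e. from the second call onwards):
```python
from litellm import completion

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Here is the full text of a complex legal agreement " * 400,
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "What are the key terms and conditions in this agreement?"},
]

for i in range(2):
    resp = completion(model="gemini/gemini-1.5-flash-002", messages=messages, max_tokens=50)
    details = resp.usage.prompt_tokens_details
    cached_tokens = details.cached_tokens if details else 0
    print(f"call {i + 1}: cached prompt tokens={cached_tokens}, cost={resp._hidden_params['response_cost']}")
```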

Binary image file added (gemini context caching architecture diagram, 314 KiB) - not shown.


@ -1049,9 +1049,11 @@ from .llms.petals.completion.transformation import PetalsConfig
from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
GoogleAIStudioGeminiConfig,
VertexAIConfig,
GoogleAIStudioGeminiConfig as GeminiConfig,
)
from .llms.gemini.chat.transformation import (
GoogleAIStudioGeminiConfig,
GoogleAIStudioGeminiConfig as GeminiConfig, # aliased to maintain backwards compatibility
)


@ -8,15 +8,6 @@ from pydantic import BaseModel
import litellm
import litellm._logging
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_character as google_cost_per_character,
)
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_router as google_cost_router,
)
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
from litellm.llms.anthropic.cost_calculation import (
cost_per_token as anthropic_cost_per_token,
@ -36,14 +27,25 @@ from litellm.llms.cohere.cost_calculator import (
from litellm.llms.databricks.cost_calculator import (
cost_per_token as databricks_cost_per_token,
)
from litellm.llms.deepseek.cost_calculator import (
cost_per_token as deepseek_cost_per_token,
)
from litellm.llms.fireworks_ai.cost_calculator import (
cost_per_token as fireworks_ai_cost_per_token,
)
from litellm.llms.gemini.cost_calculator import cost_per_token as gemini_cost_per_token
from litellm.llms.openai.cost_calculation import (
cost_per_second as openai_cost_per_second,
)
from litellm.llms.openai.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.llms.vertex_ai.cost_calculator import (
cost_per_character as google_cost_per_character,
)
from litellm.llms.vertex_ai.cost_calculator import (
cost_per_token as google_cost_per_token,
)
from litellm.llms.vertex_ai.cost_calculator import cost_router as google_cost_router
from litellm.llms.vertex_ai.image_generation.cost_calculator import (
cost_calculator as vertex_ai_image_cost_calculator,
)
@ -272,12 +274,9 @@ def cost_per_token( # noqa: PLR0915
model=model, usage=usage_block, response_time_ms=response_time_ms
)
elif custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,
custom_llm_provider=custom_llm_provider,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
return gemini_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "deepseek":
return deepseek_cost_per_token(model=model, usage=usage_block)
else:
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider


@ -5,6 +5,14 @@ from typing import Optional, Tuple
import litellm
from litellm import verbose_logger
from litellm.types.utils import ModelInfo, Usage
from litellm.utils import get_model_info
def _is_above_128k(tokens: float) -> bool:
if tokens > 128000:
return True
return False
def _generic_cost_per_character(
@ -80,3 +88,93 @@ def _generic_cost_per_character(
completion_cost = None
return prompt_cost, completion_cost
def _get_prompt_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
"""
Return prompt cost for a given model and usage.
If input_tokens > 128k and `input_cost_per_token_above_128k_tokens` is set, then we use the `input_cost_per_token_above_128k_tokens` field.
"""
input_cost_per_token_above_128k_tokens = model_info.get(
"input_cost_per_token_above_128k_tokens"
)
if _is_above_128k(usage.prompt_tokens) and input_cost_per_token_above_128k_tokens:
return input_cost_per_token_above_128k_tokens
return model_info["input_cost_per_token"]
def _get_completion_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
"""
Return completion cost for a given model and usage.
If completion_tokens > 128k and `output_cost_per_token_above_128k_tokens` is set, then we use the `output_cost_per_token_above_128k_tokens` field.
"""
output_cost_per_token_above_128k_tokens = model_info.get(
"output_cost_per_token_above_128k_tokens"
)
if (
_is_above_128k(usage.completion_tokens)
and output_cost_per_token_above_128k_tokens
):
return output_cost_per_token_above_128k_tokens
return model_info["output_cost_per_token"]
def generic_cost_per_token(
model: str, usage: Usage, custom_llm_provider: str
) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Handles context caching as well.
Input:
- model: str, the model name without provider prefix
- usage: LiteLLM Usage block, containing anthropic caching information
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
## GET MODEL INFO
model_info = get_model_info(model=model, custom_llm_provider=custom_llm_provider)
## CALCULATE INPUT COST
### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
prompt_cost = 0.0
### PROCESSING COST
non_cache_hit_tokens = usage.prompt_tokens
cache_hit_tokens = 0
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cache_hit_tokens = usage.prompt_tokens_details.cached_tokens
non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens
prompt_base_cost = _get_prompt_token_base_cost(model_info=model_info, usage=usage)
prompt_cost = float(non_cache_hit_tokens) * prompt_base_cost
_cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
if (
_cache_read_input_token_cost is not None
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
prompt_cost += (
float(usage.prompt_tokens_details.cached_tokens)
* _cache_read_input_token_cost
)
### CACHE WRITING COST
_cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
if _cache_creation_input_token_cost is not None:
prompt_cost += (
float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
)
## CALCULATE OUTPUT COST
completion_base_cost = _get_completion_token_base_cost(
model_info=model_info, usage=usage
)
completion_cost = usage["completion_tokens"] * completion_base_cost
return prompt_cost, completion_cost
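For reference, a rough sketch of how this helper prices a cache-hit request (illustrative token counts; assumes the bundled model cost map is loaded and, as added in this PR, contains `cache_read_input_token_cost` for `gemini/gemini-1.5-flash-002`):
```python
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import PromptTokensDetailsWrapper, Usage

usage = Usage(
    prompt_tokens=10_000,  # total prompt tokens, including cache hits
    completion_tokens=200,
    total_tokens=10_200,
    prompt_tokens_details=PromptTokensDetailsWrapper(cached_tokens=8_000),
)

prompt_cost, completion_cost = generic_cost_per_token(
    model="gemini-1.5-flash-002", usage=usage, custom_llm_provider="gemini"
)
# prompt_cost     = 2,000 non-cached tokens * input_cost_per_token
#                 + 8,000 cached tokens     * cache_read_input_token_cost
#                 + cache-creation tokens   * cache_creation_input_token_cost (0 here)
# completion_cost = 200 * output_cost_per_token
# (above 128k prompt/completion tokens, the *_above_128k_tokens rates apply instead)
print(prompt_cost, completion_cost)
```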


@ -5,8 +5,8 @@ Helper util for handling anthropic-specific cost calculation
from typing import Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage
from litellm.utils import get_model_info
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
@ -20,40 +20,6 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
## GET MODEL INFO
model_info = get_model_info(model=model, custom_llm_provider="anthropic")
## CALCULATE INPUT COST
### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
prompt_cost = 0.0
### PROCESSING COST
non_cache_hit_tokens = usage.prompt_tokens
cache_hit_tokens = 0
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cache_hit_tokens = usage.prompt_tokens_details.cached_tokens
non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens
prompt_cost = float(non_cache_hit_tokens) * model_info["input_cost_per_token"]
_cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
if (
_cache_read_input_token_cost is not None
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
prompt_cost += (
float(usage.prompt_tokens_details.cached_tokens)
* _cache_read_input_token_cost
)
### CACHE WRITING COST
_cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
if _cache_creation_input_token_cost is not None:
prompt_cost += (
float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
)
## CALCULATE OUTPUT COST
completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]
return prompt_cost, completion_cost
return generic_cost_per_token(
model=model, usage=usage, custom_llm_provider="anthropic"
)


@ -0,0 +1,21 @@
"""
Cost calculator for DeepSeek Chat models.
Handles prompt caching scenario.
"""
from typing import Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Follows the same logic as Anthropic's cost per token calculation.
"""
return generic_cost_per_token(
model=model, usage=usage, custom_llm_provider="deepseek"
)


@ -0,0 +1,131 @@
from typing import Dict, List, Optional
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import ContentType, PartType
from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_history
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
"""
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
The class `GoogleAIStudioGeminiConfig` provides configuration for the Google AI Studio's Gemini API interface. Below are the parameters:
- `temperature` (float): This controls the degree of randomness in token selection.
- `max_output_tokens` (integer): This sets the maximum number of tokens in the text output. The default value is 256.
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. Other values - `application/json`.
- `response_schema` (dict): Optional. Output response schema of the generated candidate text when response mime type can have schema. Schema can be objects, primitives or arrays and is a subset of OpenAPI schema. If set, a compatible response_mime_type must also be set. Compatible mimetypes: application/json: Schema for JSON response.
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
response_mime_type: Optional[str] = None
response_schema: Optional[dict] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
response_mime_type: Optional[str] = None,
response_schema: Optional[dict] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"stream",
"tools",
"tool_choice",
"functions",
"response_format",
"n",
"stop",
"logprobs",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
return super().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=drop_params,
)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
for message in messages:
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if element.get("type") == "image_url":
img_element = element
_image_url: Optional[str] = None
if isinstance(img_element.get("image_url"), dict):
_image_url = img_element["image_url"].get("url") # type: ignore
else:
_image_url = img_element.get("image_url") # type: ignore
if _image_url and "https://" in _image_url:
image_obj = convert_to_anthropic_image_obj(_image_url)
img_element["image_url"] = ( # type: ignore
convert_generic_image_chunk_to_openai_image_obj(
image_obj
)
)
return _gemini_convert_messages_with_history(messages=messages)
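For context, a minimal sketch of the call path this config serves: an OpenAI-style `image_url` pointing at an https URL is converted to inline base64 data by `_transform_messages`, since Google AI Studio Gemini does not accept remote image URLs (model name and URL below are placeholders):
```python
from litellm import completion

resp = completion(
    model="gemini/gemini-1.5-flash-002",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # https URL gets converted to inline base64 data before being sent
                {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
            ],
        }
    ],
)
print(resp.choices[0].message.content)
```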


@ -0,0 +1 @@
[Go here for the Gemini Context Caching code](../../vertex_ai/context_caching/)


@ -0,0 +1,21 @@
"""
This file is used to calculate the cost of the Gemini API.
Handles the context caching for Gemini API.
"""
from typing import Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Follows the same logic as Anthropic's cost per token calculation.
"""
return generic_cost_per_token(
model=model, usage=usage, custom_llm_provider="gemini"
)


@ -10,13 +10,43 @@ from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message
from ..common_utils import VertexAIError, get_supports_system_message
from ..common_utils import get_supports_system_message
from ..gemini.transformation import (
_gemini_convert_messages_with_history,
_transform_system_message,
)
def get_first_continuous_block_idx(
filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
) -> int:
"""
Find the array index that ends the first continuous sequence of message blocks.
Args:
filtered_messages: List of tuples containing (index, message) pairs
Returns:
int: The array index where the first continuous sequence ends
"""
if not filtered_messages:
return -1
if len(filtered_messages) == 1:
return 0
current_value = filtered_messages[0][0]
# Search forward through the array indices
for i in range(1, len(filtered_messages)):
if filtered_messages[i][0] != current_value + 1:
return i - 1
current_value = filtered_messages[i][0]
# If we made it through the whole list, return the last index
return len(filtered_messages) - 1
def separate_cached_messages(
messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
@ -41,22 +71,11 @@ def separate_cached_messages(
filtered_messages.append((idx, message))
# Validate only one block of continuous cached messages
if len(filtered_messages) > 1:
expected_idx = filtered_messages[0][0] + 1
for idx, _ in filtered_messages[1:]:
if idx != expected_idx:
raise VertexAIError(
status_code=422,
message="Gemini Context Caching only supports 1 message/block of continuous messages. Your idx, messages were - {}".format(
filtered_messages
),
)
expected_idx += 1
last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
# Separate messages based on the block of cached messages
if filtered_messages:
if filtered_messages and last_continuous_block_idx is not None:
first_cached_idx = filtered_messages[0][0]
last_cached_idx = filtered_messages[-1][0]
last_cached_idx = filtered_messages[last_continuous_block_idx][0]
cached_messages = messages[first_cached_idx : last_cached_idx + 1]
non_cached_messages = (


@ -4,6 +4,7 @@ from typing import Literal, Optional, Tuple, Union
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import _is_above_128k
"""
Gemini pricing covers:
@ -22,12 +23,6 @@ Google AI Studio -> token based pricing
models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"]
def _is_above_128k(tokens: float) -> bool:
if tokens > 128000:
return True
return False
def cost_router(
model: str,
custom_llm_provider: str,
@ -47,8 +42,6 @@ def cost_router(
or "codestral" in model
):
return "cost_per_token"
elif custom_llm_provider == "gemini":
return "cost_per_token"
elif custom_llm_provider == "vertex_ai" and (
call_type == "embedding" or call_type == "aembedding"
):


@ -26,10 +26,6 @@ import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
@ -52,7 +48,6 @@ from litellm.types.llms.vertex_ai import (
GenerateContentResponseBody,
HttpxPartType,
LogprobsResult,
PartType,
ToolConfig,
Tools,
)
@ -60,7 +55,9 @@ from litellm.types.utils import (
ChatCompletionTokenLogprob,
ChoiceLogprobs,
GenericStreamingChunk,
PromptTokensDetailsWrapper,
TopLogprob,
Usage,
)
from litellm.utils import CustomStreamWrapper, ModelResponse
@ -679,7 +676,7 @@ class VertexGeminiConfig(BaseConfig):
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
usage = Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
@ -717,7 +714,7 @@ class VertexGeminiConfig(BaseConfig):
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
usage = Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
@ -731,6 +728,35 @@ class VertexGeminiConfig(BaseConfig):
return model_response
def _calculate_usage(
self,
completion_response: GenerateContentResponseBody,
) -> Usage:
cached_tokens: Optional[int] = None
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
if "cachedContentTokenCount" in completion_response["usageMetadata"]:
cached_tokens = completion_response["usageMetadata"][
"cachedContentTokenCount"
]
if cached_tokens is not None:
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cached_tokens,
)
## GET USAGE ##
usage = Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
prompt_tokens_details=prompt_tokens_details,
)
return usage
def transform_response(
self,
model: str,
@ -854,19 +880,7 @@ class VertexGeminiConfig(BaseConfig):
model_response.choices.append(choice)
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get(
"totalTokenCount", 0
),
)
usage = self._calculate_usage(completion_response=completion_response)
setattr(model_response, "usage", usage)
## ADD GROUNDING METADATA ##
@ -943,126 +957,6 @@ class VertexGeminiConfig(BaseConfig):
return default_headers
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
"""
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
The class `GoogleAIStudioGeminiConfig` provides configuration for the Google AI Studio's Gemini API interface. Below are the parameters:
- `temperature` (float): This controls the degree of randomness in token selection.
- `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. Other values - `application/json`.
- `response_schema` (dict): Optional. Output response schema of the generated candidate text when response mime type can have schema. Schema can be objects, primitives or arrays and is a subset of OpenAPI schema. If set, a compatible response_mime_type must also be set. Compatible mimetypes: application/json: Schema for JSON response.
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
response_mime_type: Optional[str] = None
response_schema: Optional[dict] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
response_mime_type: Optional[str] = None,
response_schema: Optional[dict] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"stream",
"tools",
"tool_choice",
"functions",
"response_format",
"n",
"stop",
"logprobs",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
return super().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=drop_params,
)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
for message in messages:
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if element.get("type") == "image_url":
img_element = element
_image_url: Optional[str] = None
if isinstance(img_element.get("image_url"), dict):
_image_url = img_element["image_url"].get("url") # type: ignore
else:
_image_url = img_element.get("image_url") # type: ignore
if _image_url and "https://" in _image_url:
image_obj = convert_to_anthropic_image_obj(_image_url)
img_element["image_url"] = ( # type: ignore
convert_generic_image_chunk_to_openai_image_obj(
image_obj
)
)
return _gemini_convert_messages_with_history(messages=messages)
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,


@ -1944,6 +1944,8 @@
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000014,
"input_cost_per_token_cache_hit": 0.000000014,
"cache_read_input_token_cost": 0.000000014,
"cache_creation_input_token_cost": 0.0,
"output_cost_per_token": 0.00000028,
"litellm_provider": "deepseek",
"mode": "chat",
@ -3758,6 +3760,8 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
@ -3783,6 +3787,8 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
@ -3842,6 +3848,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
@ -3866,6 +3873,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
@ -3890,6 +3898,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"


@ -183,6 +183,7 @@ class UsageMetadata(TypedDict, total=False):
promptTokenCount: int
totalTokenCount: int
candidatesTokenCount: int
cachedContentTokenCount: int
class CachedContent(TypedDict, total=False):


@ -1944,6 +1944,8 @@
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000014,
"input_cost_per_token_cache_hit": 0.000000014,
"cache_read_input_token_cost": 0.000000014,
"cache_creation_input_token_cost": 0.0,
"output_cost_per_token": 0.00000028,
"litellm_provider": "deepseek",
"mode": "chat",
@ -3758,6 +3760,8 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
@ -3783,6 +3787,8 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003,
@ -3842,6 +3848,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
@ -3866,6 +3873,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
@ -3890,6 +3898,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"


@ -5,6 +5,7 @@ import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
import uuid
sys.path.insert(
0, os.path.abspath("../..")
@ -45,6 +46,7 @@ def _usage_format_tests(usage: litellm.Usage):
}
```
"""
print(f"usage={usage}")
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens
@ -342,54 +344,75 @@ class BaseLLMChatTest(ABC):
print("Model does not support prompt caching")
pytest.skip("Model does not support prompt caching")
try:
for _ in range(2):
response = self.completion_function(
**base_completion_call_args,
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
],
max_tokens=10,
)
uuid_str = str(uuid.uuid4())
messages = [
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement {}".format(
uuid_str
)
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
]
_usage_format_tests(response.usage)
try:
## call 1
response = self.completion_function(
**base_completion_call_args,
messages=messages,
max_tokens=10,
)
initial_cost = response._hidden_params["response_cost"]
## call 2
response = self.completion_function(
**base_completion_call_args,
messages=messages,
max_tokens=10,
)
cached_cost = response._hidden_params["response_cost"]
assert (
cached_cost <= initial_cost
), "Cached cost={} should be less than initial cost={}".format(
cached_cost, initial_cost
)
_usage_format_tests(response.usage)
print("response=", response)
print("response.usage=", response.usage)


@ -1,9 +1,21 @@
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from base_llm_unit_tests import BaseLLMChatTest
from litellm.llms.vertex_ai.context_caching.transformation import (
separate_cached_messages,
)
class TestGoogleAIStudioGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-1.5-flash"}
return {"model": "gemini/gemini-1.5-flash-002"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
@ -13,3 +25,50 @@ class TestGoogleAIStudioGemini(BaseLLMChatTest):
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
print(result)
def test_gemini_context_caching_separate_messages():
messages = [
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
]
cached_messages, non_cached_messages = separate_cached_messages(messages)
print(cached_messages)
print(non_cached_messages)
assert len(cached_messages) > 0, "Cached messages should be present"
assert len(non_cached_messages) > 0, "Non-cached messages should be present"


@ -353,9 +353,9 @@ def test_all_model_configs():
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIConfig,
GoogleAIStudioGeminiConfig,
VertexGeminiConfig,
)
from litellm.llms.gemini.chat.transformation import GoogleAIStudioGeminiConfig
assert "max_completion_tokens" in VertexAIConfig().get_supported_openai_params()


@ -1163,14 +1163,19 @@ def test_completion_cost_azure_common_deployment_name():
assert "azure/gpt-4" == mock_client.call_args.kwargs["base_model"]
def test_completion_cost_anthropic_prompt_caching():
@pytest.mark.parametrize(
"model, custom_llm_provider",
[
("claude-3-5-sonnet-20240620", "anthropic"),
("gemini/gemini-1.5-flash-001", "gemini"),
],
)
def test_completion_cost_prompt_caching(model, custom_llm_provider):
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
from litellm.utils import Choices, Message, ModelResponse, Usage
model = "anthropic/claude-3-5-sonnet-20240620"
## WRITE TO CACHE ## (MORE EXPENSIVE)
response_1 = ModelResponse(
id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424",
@ -1187,7 +1192,7 @@ def test_completion_cost_anthropic_prompt_caching():
)
],
created=1725036547,
model="claude-3-5-sonnet-20240620",
model=model,
object="chat.completion",
system_fingerprint=None,
usage=Usage(
@ -1203,7 +1208,7 @@ def test_completion_cost_anthropic_prompt_caching():
cost_1 = completion_cost(model=model, completion_response=response_1)
_model_info = litellm.get_model_info(
model="claude-3-5-sonnet-20240620", custom_llm_provider="anthropic"
model=model, custom_llm_provider=custom_llm_provider
)
expected_cost = (
(
@ -1211,11 +1216,12 @@ def test_completion_cost_anthropic_prompt_caching():
- response_1.usage.prompt_tokens_details.cached_tokens
)
* _model_info["input_cost_per_token"]
+ response_1.usage.prompt_tokens_details.cached_tokens
+ (response_1.usage.prompt_tokens_details.cached_tokens or 0)
* _model_info["cache_read_input_token_cost"]
+ response_1.usage.cache_creation_input_tokens
+ (response_1.usage.cache_creation_input_tokens or 0)
* _model_info["cache_creation_input_token_cost"]
+ response_1.usage.completion_tokens * _model_info["output_cost_per_token"]
+ (response_1.usage.completion_tokens or 0)
* _model_info["output_cost_per_token"]
) # Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
assert round(expected_cost, 5) == round(cost_1, 5)
@ -1238,7 +1244,7 @@ def test_completion_cost_anthropic_prompt_caching():
)
],
created=1725036547,
model="claude-3-5-sonnet-20240620",
model=model,
object="chat.completion",
system_fingerprint=None,
usage=Usage(
@ -2437,7 +2443,7 @@ def test_completion_cost_params_2():
def test_completion_cost_params_gemini_3():
from litellm.utils import Choices, Message, ModelResponse, Usage
from litellm.litellm_core_utils.llm_cost_calc.google import cost_per_character
from litellm.llms.vertex_ai.cost_calculator import cost_per_character
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")