LiteLLM Minor Fixes & Improvements (12/23/2024) - p3 (#7394)

* build(model_prices_and_context_window.json): add gemini-1.5-flash context caching

* fix(context_caching/transformation.py): just use last identified cache point

Fixes https://github.com/BerriAI/litellm/issues/6738

* fix(context_caching/transformation.py): pick first contiguous block - handles system message error from google

Fixes https://github.com/BerriAI/litellm/issues/6738

* fix(vertex_ai/gemini/): track context caching tokens

* refactor(gemini/): place transformation.py inside `chat/` folder

makes it easy for users to know we support the equivalent endpoint

* fix: fix import

* refactor(vertex_ai/): move vertex_ai cost calc inside vertex_ai/ folder

makes it easier to find the cost calculation logic

* fix: fix linting errors

* fix: fix circular import

* feat(gemini/cost_calculator.py): support gemini context caching cost calculation

generalizes Anthropic's cost calculation function and reuses it across Anthropic + Gemini

* build(model_prices_and_context_window.json): add cost tracking for gemini-1.5-flash-002 w/ context caching

Closes https://github.com/BerriAI/litellm/issues/6891

* docs(gemini.md): add gemini context caching architecture diagram

makes it easier for users to understand how context caching works

* docs(gemini.md): link to relevant gemini context caching code

* docs(gemini/context_caching): add README on GitHub, making it easy for devs to know context caching is supported and where to find the code

* fix(llm_cost_calc/utils.py): handle gemini 128k token diff cost calc scenario

* fix(deepseek/cost_calculator.py): support deepseek context caching cost calculation

* test: fix test
Krish Dholakia 2024-12-23 22:02:52 -08:00 committed by GitHub
parent 442d309bcd
commit c3edfc2c92
20 changed files with 719 additions and 447 deletions

View file

@ -10,7 +10,8 @@ import TabItem from '@theme/TabItem';
| Provider Route on LiteLLM | `gemini/` | | Provider Route on LiteLLM | `gemini/` |
| Provider Doc | [Google AI Studio ↗](https://ai.google.dev/aistudio) | | Provider Doc | [Google AI Studio ↗](https://ai.google.dev/aistudio) |
| API Endpoint for Provider | https://generativelanguage.googleapis.com | | API Endpoint for Provider | https://generativelanguage.googleapis.com |
| Supported Endpoints | `/chat/completions`, `/embeddings` | | Supported OpenAI Endpoints | `/chat/completions`, `/embeddings`, `/completions` |
| Pass-through Endpoint | [Supported](../pass_through/google_ai_studio.md) |
<br /> <br />
@ -552,175 +553,6 @@ content = response.get('choices', [{}])[0].get('message', {}).get('content')
print(content) print(content)
``` ```
## Context Caching
Google AI Studio context caching is supported by adding
```bash
{
    ...,
    "cache_control": {"type": "ephemeral"}
}
```
in your message content block.
:::note
Gemini Context Caching only allows 1 block of continuous messages to be cached.
The raw request to Gemini looks like this:
```bash
curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-001:generateContent?key=$GOOGLE_API_KEY" \
-H 'Content-Type: application/json' \
-d '{
"contents": [
{
"parts":[{
"text": "Please summarize this transcript"
}],
"role": "user"
},
],
"cachedContent": "'$CACHE_NAME'"
}'
```
:::
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
for _ in range(2):
resp = completion(
model="gemini/gemini-1.5-pro",
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
}]
)
print(resp.usage) # 👈 2nd usage block will be less, since cached tokens used
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-1.5-pro
litellm_params:
model: gemini/gemini-1.5-pro
api_key: os.environ/GEMINI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
[**See Langchain, OpenAI JS, Llamaindex, etc. examples**](../proxy/user_keys.md#request-format)
<Tabs>
<TabItem value="curl" label="Curl">
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
"model": "gemini-1.5-pro",
"messages": [
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
}],
}'
```
</TabItem>
<TabItem value="openai-python" label="OpenAI Python SDK">
```python
import openai
client = openai.AsyncOpenAI(
api_key="anything", # litellm proxy api key
base_url="http://0.0.0.0:4000" # litellm proxy base url
)
response = await client.chat.completions.create(
model="gemini-1.5-pro",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
}
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
]
)
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Usage - PDF / Videos / etc. Files ## Usage - PDF / Videos / etc. Files
### Inline Data (e.g. audio stream) ### Inline Data (e.g. audio stream)
@ -857,3 +689,191 @@ response = litellm.completion(
| gemini-pro | `completion(model='gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` | | gemini-pro | `completion(model='gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro-latest | `completion(model='gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` | | gemini-1.5-pro-latest | `completion(model='gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-pro-vision | `completion(model='gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` | | gemini-pro-vision | `completion(model='gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
## Context Caching
Google AI Studio context caching is supported by adding

```bash
{
    {
        "role": "system",
        "content": ...,
        "cache_control": {"type": "ephemeral"} # 👈 KEY CHANGE
    },
    ...
}
```
in your message content block.
### Architecture Diagram
<Image img={require('../../img/gemini_context_caching.png')} />
**Notes:**
- [Relevant code](https://github.com/BerriAI/litellm/blob/main/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py#L255)
- Gemini Context Caching only allows 1 block of continuous messages to be cached.
- If multiple non-continuous blocks contain `cache_control`, only the first continuous block is used (it is sent to `/cachedContent` in the [Gemini format](https://ai.google.dev/api/caching#cache_create-SHELL)) - see the sketch below the example request.
- The raw request to Gemini's `/generateContent` endpoint looks like this:
```bash
curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-001:generateContent?key=$GOOGLE_API_KEY" \
-H 'Content-Type: application/json' \
-d '{
      "contents": [
        {
          "parts": [{
            "text": "Please summarize this transcript"
          }],
          "role": "user"
        }
      ],
      "cachedContent": "'$CACHE_NAME'"
}'
```
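
To preview which messages LiteLLM will treat as the cached block, you can call the splitting helper directly. This is a minimal sketch using the internal `separate_cached_messages` helper shown in the linked code (internal module paths may change between releases):

```python
from litellm.llms.vertex_ai.context_caching.transformation import (
    separate_cached_messages,
)

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Here is the full text of a complex legal agreement " * 400,
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "What are the key terms and conditions in this agreement?"},
]

# `cached` is created once via /cachedContents; `non_cached` is sent to /generateContent
# along with the returned cache name.
cached, non_cached = separate_cached_messages(messages=messages)
print(len(cached), len(non_cached))
```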
### Example Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion

for _ in range(2):
    resp = completion(
        model="gemini/gemini-1.5-pro",
        messages=[
            # System Message
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "Here is the full text of a complex legal agreement" * 4000,
                        "cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
                    }
                ],
            },
            # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What are the key terms and conditions in this agreement?",
                        "cache_control": {"type": "ephemeral"},
                    }
                ],
            },
        ],
    )

    print(resp.usage) # 👈 2nd usage block will be less, since cached tokens used
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
  - model_name: gemini-1.5-pro
    litellm_params:
      model: gemini/gemini-1.5-pro
      api_key: os.environ/GEMINI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
[**See Langchain, OpenAI JS, Llamaindex, etc. examples**](../proxy/user_keys.md#request-format)
<Tabs>
<TabItem value="curl" label="Curl">
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gemini-1.5-pro",
    "messages": [
        # System Message
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a complex legal agreement" * 4000,
                    "cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
                }
            ],
        },
        # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What are the key terms and conditions in this agreement?",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    ],
}'
```
</TabItem>
<TabItem value="openai-python" label="OpenAI Python SDK">
```python
import openai

client = openai.AsyncOpenAI(
    api_key="anything",             # litellm proxy api key
    base_url="http://0.0.0.0:4000"  # litellm proxy base url
)

response = await client.chat.completions.create(
    model="gemini-1.5-pro",
    messages=[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a complex legal agreement" * 4000,
                    "cache_control": {"type": "ephemeral"}, # 👈 KEY CHANGE
                }
            ],
        },
        {
            "role": "user",
            "content": "what are the key terms and conditions in this agreement?",
        },
    ],
)
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>

Binary file not shown (new image added, 314 KiB)

View file

@ -1049,9 +1049,11 @@ from .llms.petals.completion.transformation import PetalsConfig
from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig, VertexGeminiConfig,
GoogleAIStudioGeminiConfig,
VertexAIConfig, VertexAIConfig,
GoogleAIStudioGeminiConfig as GeminiConfig, )
from .llms.gemini.chat.transformation import (
GoogleAIStudioGeminiConfig,
GoogleAIStudioGeminiConfig as GeminiConfig, # aliased to maintain backwards compatibility
) )

View file

@ -8,15 +8,6 @@ from pydantic import BaseModel
import litellm import litellm
import litellm._logging import litellm._logging
from litellm import verbose_logger from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_character as google_cost_per_character,
)
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_router as google_cost_router,
)
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
from litellm.llms.anthropic.cost_calculation import ( from litellm.llms.anthropic.cost_calculation import (
cost_per_token as anthropic_cost_per_token, cost_per_token as anthropic_cost_per_token,
@ -36,14 +27,25 @@ from litellm.llms.cohere.cost_calculator import (
from litellm.llms.databricks.cost_calculator import ( from litellm.llms.databricks.cost_calculator import (
cost_per_token as databricks_cost_per_token, cost_per_token as databricks_cost_per_token,
) )
from litellm.llms.deepseek.cost_calculator import (
cost_per_token as deepseek_cost_per_token,
)
from litellm.llms.fireworks_ai.cost_calculator import ( from litellm.llms.fireworks_ai.cost_calculator import (
cost_per_token as fireworks_ai_cost_per_token, cost_per_token as fireworks_ai_cost_per_token,
) )
from litellm.llms.gemini.cost_calculator import cost_per_token as gemini_cost_per_token
from litellm.llms.openai.cost_calculation import ( from litellm.llms.openai.cost_calculation import (
cost_per_second as openai_cost_per_second, cost_per_second as openai_cost_per_second,
) )
from litellm.llms.openai.cost_calculation import cost_per_token as openai_cost_per_token from litellm.llms.openai.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.llms.vertex_ai.cost_calculator import (
cost_per_character as google_cost_per_character,
)
from litellm.llms.vertex_ai.cost_calculator import (
cost_per_token as google_cost_per_token,
)
from litellm.llms.vertex_ai.cost_calculator import cost_router as google_cost_router
from litellm.llms.vertex_ai.image_generation.cost_calculator import ( from litellm.llms.vertex_ai.image_generation.cost_calculator import (
cost_calculator as vertex_ai_image_cost_calculator, cost_calculator as vertex_ai_image_cost_calculator,
) )
@ -272,12 +274,9 @@ def cost_per_token( # noqa: PLR0915
model=model, usage=usage_block, response_time_ms=response_time_ms model=model, usage=usage_block, response_time_ms=response_time_ms
) )
elif custom_llm_provider == "gemini": elif custom_llm_provider == "gemini":
return google_cost_per_token( return gemini_cost_per_token(model=model, usage=usage_block)
model=model_without_prefix, elif custom_llm_provider == "deepseek":
custom_llm_provider=custom_llm_provider, return deepseek_cost_per_token(model=model, usage=usage_block)
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
else: else:
model_info = litellm.get_model_info( model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
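
With this routing in place, Gemini and DeepSeek costs flow through their provider-specific calculators, which understand cached prompt tokens. A hedged usage sketch (model name and token counts are illustrative):

```python
import litellm

# Routed to the new gemini cost calculator under the hood.
prompt_cost, completion_cost = litellm.cost_per_token(
    model="gemini/gemini-1.5-flash-002",
    prompt_tokens=1_000,
    completion_tokens=100,
)
print(prompt_cost, completion_cost)
```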

View file

@ -5,6 +5,14 @@ from typing import Optional, Tuple
import litellm import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.types.utils import ModelInfo, Usage
from litellm.utils import get_model_info
def _is_above_128k(tokens: float) -> bool:
    if tokens > 128000:
        return True
    return False
def _generic_cost_per_character( def _generic_cost_per_character(
@ -80,3 +88,93 @@ def _generic_cost_per_character(
completion_cost = None completion_cost = None
return prompt_cost, completion_cost return prompt_cost, completion_cost
def _get_prompt_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
    """
    Return prompt cost for a given model and usage.

    If input_tokens > 128k and `input_cost_per_token_above_128k_tokens` is set, then we use the `input_cost_per_token_above_128k_tokens` field.
    """
    input_cost_per_token_above_128k_tokens = model_info.get(
        "input_cost_per_token_above_128k_tokens"
    )
    if _is_above_128k(usage.prompt_tokens) and input_cost_per_token_above_128k_tokens:
        return input_cost_per_token_above_128k_tokens
    return model_info["input_cost_per_token"]


def _get_completion_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
    """
    Return completion cost for a given model and usage.

    If completion_tokens > 128k and `output_cost_per_token_above_128k_tokens` is set, then we use the `output_cost_per_token_above_128k_tokens` field.
    """
    output_cost_per_token_above_128k_tokens = model_info.get(
        "output_cost_per_token_above_128k_tokens"
    )
    if (
        _is_above_128k(usage.completion_tokens)
        and output_cost_per_token_above_128k_tokens
    ):
        return output_cost_per_token_above_128k_tokens
    return model_info["output_cost_per_token"]


def generic_cost_per_token(
    model: str, usage: Usage, custom_llm_provider: str
) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Handles context caching as well.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    ## GET MODEL INFO
    model_info = get_model_info(model=model, custom_llm_provider=custom_llm_provider)

    ## CALCULATE INPUT COST
    ### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
    prompt_cost = 0.0
    ### PROCESSING COST
    non_cache_hit_tokens = usage.prompt_tokens
    cache_hit_tokens = 0
    if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
        cache_hit_tokens = usage.prompt_tokens_details.cached_tokens
        non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens

    prompt_base_cost = _get_prompt_token_base_cost(model_info=model_info, usage=usage)

    prompt_cost = float(non_cache_hit_tokens) * prompt_base_cost

    _cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
    if (
        _cache_read_input_token_cost is not None
        and usage.prompt_tokens_details
        and usage.prompt_tokens_details.cached_tokens
    ):
        prompt_cost += (
            float(usage.prompt_tokens_details.cached_tokens)
            * _cache_read_input_token_cost
        )

    ### CACHE WRITING COST
    _cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
    if _cache_creation_input_token_cost is not None:
        prompt_cost += (
            float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
        )

    ## CALCULATE OUTPUT COST
    completion_base_cost = _get_completion_token_base_cost(
        model_info=model_info, usage=usage
    )
    completion_cost = usage["completion_tokens"] * completion_base_cost

    return prompt_cost, completion_cost
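
For intuition, here is a worked sketch of the arithmetic above. The per-token prices are the gemini-1.5-flash values added to the pricing JSON in this PR; the token counts are made up:

```python
# Illustrative numbers: a 10k-token prompt, 8k of which were a cache hit, under the 128k tier.
prompt_tokens = 10_000
cached_tokens = 8_000              # usage.prompt_tokens_details.cached_tokens
completion_tokens = 200

input_cost_per_token = 0.000000075           # model_info["input_cost_per_token"]
cache_read_input_token_cost = 0.00000001875  # model_info["cache_read_input_token_cost"]
output_cost_per_token = 0.0000003            # model_info["output_cost_per_token"]

prompt_cost = (
    (prompt_tokens - cached_tokens) * input_cost_per_token
    + cached_tokens * cache_read_input_token_cost
)
completion_cost = completion_tokens * output_cost_per_token
print(prompt_cost, completion_cost)  # cache hits are billed at the cheaper cache-read rate
```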

View file

@ -5,8 +5,8 @@ Helper util for handling anthropic-specific cost calculation
from typing import Tuple from typing import Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage from litellm.types.utils import Usage
from litellm.utils import get_model_info
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]: def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
@ -20,40 +20,6 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
Returns: Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
""" """
## GET MODEL INFO return generic_cost_per_token(
model_info = get_model_info(model=model, custom_llm_provider="anthropic") model=model, usage=usage, custom_llm_provider="anthropic"
## CALCULATE INPUT COST
### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
prompt_cost = 0.0
### PROCESSING COST
non_cache_hit_tokens = usage.prompt_tokens
cache_hit_tokens = 0
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cache_hit_tokens = usage.prompt_tokens_details.cached_tokens
non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens
prompt_cost = float(non_cache_hit_tokens) * model_info["input_cost_per_token"]
_cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
if (
_cache_read_input_token_cost is not None
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
prompt_cost += (
float(usage.prompt_tokens_details.cached_tokens)
* _cache_read_input_token_cost
) )
### CACHE WRITING COST
_cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
if _cache_creation_input_token_cost is not None:
prompt_cost += (
float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
)
## CALCULATE OUTPUT COST
completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]
return prompt_cost, completion_cost

View file

@ -0,0 +1,21 @@
"""
Cost calculator for DeepSeek Chat models.
Handles prompt caching scenario.
"""
from typing import Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Follows the same logic as Anthropic's cost per token calculation.
"""
return generic_cost_per_token(
model=model, usage=usage, custom_llm_provider="deepseek"
)

View file

@ -0,0 +1,131 @@
from typing import Dict, List, Optional
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import ContentType, PartType
from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_history
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
"""
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
The class `GoogleAIStudioGeminiConfig` provides configuration for the Google AI Studio's Gemini API interface. Below are the parameters:
- `temperature` (float): This controls the degree of randomness in token selection.
- `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. Other values - `application/json`.
- `response_schema` (dict): Optional. Output response schema of the generated candidate text when response mime type can have schema. Schema can be objects, primitives or arrays and is a subset of OpenAPI schema. If set, a compatible response_mime_type must also be set. Compatible mimetypes: application/json: Schema for JSON response.
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
response_mime_type: Optional[str] = None
response_schema: Optional[dict] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
response_mime_type: Optional[str] = None,
response_schema: Optional[dict] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"stream",
"tools",
"tool_choice",
"functions",
"response_format",
"n",
"stop",
"logprobs",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
return super().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=drop_params,
)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
for message in messages:
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if element.get("type") == "image_url":
img_element = element
_image_url: Optional[str] = None
if isinstance(img_element.get("image_url"), dict):
_image_url = img_element["image_url"].get("url") # type: ignore
else:
_image_url = img_element.get("image_url") # type: ignore
if _image_url and "https://" in _image_url:
image_obj = convert_to_anthropic_image_obj(_image_url)
img_element["image_url"] = ( # type: ignore
convert_generic_image_chunk_to_openai_image_obj(
image_obj
)
)
return _gemini_convert_messages_with_history(messages=messages)
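
A hedged sketch of what the param mapping above does: `frequency_penalty` and `presence_penalty` are stripped before the remaining OpenAI params are translated by the parent `VertexGeminiConfig` (the exact output keys depend on that parent class):

```python
from litellm.llms.gemini.chat.transformation import GoogleAIStudioGeminiConfig

config = GoogleAIStudioGeminiConfig()
mapped = config.map_openai_params(
    non_default_params={"temperature": 0.2, "frequency_penalty": 0.5},
    optional_params={},
    model="gemini-1.5-flash",
    drop_params=False,
)
print(mapped)  # frequency_penalty is dropped; temperature is carried into the Gemini config
```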

View file

@ -0,0 +1 @@
[Go here for the Gemini Context Caching code](../../vertex_ai/context_caching/)

View file

@ -0,0 +1,21 @@
"""
This file is used to calculate the cost of the Gemini API.
Handles the context caching for Gemini API.
"""
from typing import Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Follows the same logic as Anthropic's cost per token calculation.
"""
return generic_cost_per_token(
model=model, usage=usage, custom_llm_provider="gemini"
)
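
A hedged usage sketch, assuming the local cost map contains the gemini-1.5-flash-002 prices added in this PR and a `Usage` block populated the way the Gemini response transformation builds it:

```python
from litellm.llms.gemini.cost_calculator import cost_per_token
from litellm.types.utils import PromptTokensDetailsWrapper, Usage

usage = Usage(
    prompt_tokens=10_000,
    completion_tokens=100,
    total_tokens=10_100,
    prompt_tokens_details=PromptTokensDetailsWrapper(cached_tokens=8_000),
)
prompt_cost, completion_cost = cost_per_token(model="gemini-1.5-flash-002", usage=usage)
print(prompt_cost, completion_cost)  # the 8k cached tokens are billed at the cache-read rate
```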

View file

@ -10,13 +10,43 @@ from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message from litellm.utils import is_cached_message
from ..common_utils import VertexAIError, get_supports_system_message from ..common_utils import get_supports_system_message
from ..gemini.transformation import ( from ..gemini.transformation import (
_gemini_convert_messages_with_history, _gemini_convert_messages_with_history,
_transform_system_message, _transform_system_message,
) )
def get_first_continuous_block_idx(
    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
) -> int:
    """
    Find the array index that ends the first continuous sequence of message blocks.

    Args:
        filtered_messages: List of tuples containing (index, message) pairs

    Returns:
        int: The array index where the first continuous sequence ends
    """
    if not filtered_messages:
        return -1

    if len(filtered_messages) == 1:
        return 0

    current_value = filtered_messages[0][0]

    # Search forward through the array indices
    for i in range(1, len(filtered_messages)):
        if filtered_messages[i][0] != current_value + 1:
            return i - 1
        current_value = filtered_messages[i][0]

    # If we made it through the whole list, return the last index
    return len(filtered_messages) - 1
def separate_cached_messages( def separate_cached_messages(
messages: List[AllMessageValues], messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]: ) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
@ -41,22 +71,11 @@ def separate_cached_messages(
filtered_messages.append((idx, message)) filtered_messages.append((idx, message))
# Validate only one block of continuous cached messages # Validate only one block of continuous cached messages
if len(filtered_messages) > 1: last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
expected_idx = filtered_messages[0][0] + 1
for idx, _ in filtered_messages[1:]:
if idx != expected_idx:
raise VertexAIError(
status_code=422,
message="Gemini Context Caching only supports 1 message/block of continuous messages. Your idx, messages were - {}".format(
filtered_messages
),
)
expected_idx += 1
# Separate messages based on the block of cached messages # Separate messages based on the block of cached messages
if filtered_messages: if filtered_messages and last_continuous_block_idx is not None:
first_cached_idx = filtered_messages[0][0] first_cached_idx = filtered_messages[0][0]
last_cached_idx = filtered_messages[-1][0] last_cached_idx = filtered_messages[last_continuous_block_idx][0]
cached_messages = messages[first_cached_idx : last_cached_idx + 1] cached_messages = messages[first_cached_idx : last_cached_idx + 1]
non_cached_messages = ( non_cached_messages = (
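
A small illustrative sketch of how `get_first_continuous_block_idx` behaves on the `(idx, message)` pairs it receives - only the first contiguous run of cached messages is kept:

```python
from litellm.llms.vertex_ai.context_caching.transformation import (
    get_first_continuous_block_idx,
)

# Hypothetical pairs: message indices 0 and 1 are contiguous, index 3 is not,
# so the first continuous block ends at array position 1.
pairs = [
    (0, {"role": "system", "content": "..."}),
    (1, {"role": "user", "content": "..."}),
    (3, {"role": "user", "content": "..."}),
]
assert get_first_continuous_block_idx(pairs) == 1
```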

View file

@ -4,6 +4,7 @@ from typing import Literal, Optional, Tuple, Union
import litellm import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import _is_above_128k
""" """
Gemini pricing covers: Gemini pricing covers:
@ -22,12 +23,6 @@ Google AI Studio -> token based pricing
models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"] models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"]
def _is_above_128k(tokens: float) -> bool:
if tokens > 128000:
return True
return False
def cost_router( def cost_router(
model: str, model: str,
custom_llm_provider: str, custom_llm_provider: str,
@ -47,8 +42,6 @@ def cost_router(
or "codestral" in model or "codestral" in model
): ):
return "cost_per_token" return "cost_per_token"
elif custom_llm_provider == "gemini":
return "cost_per_token"
elif custom_llm_provider == "vertex_ai" and ( elif custom_llm_provider == "vertex_ai" and (
call_type == "embedding" or call_type == "aembedding" call_type == "embedding" or call_type == "aembedding"
): ):

View file

@ -26,10 +26,6 @@ import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
@ -52,7 +48,6 @@ from litellm.types.llms.vertex_ai import (
GenerateContentResponseBody, GenerateContentResponseBody,
HttpxPartType, HttpxPartType,
LogprobsResult, LogprobsResult,
PartType,
ToolConfig, ToolConfig,
Tools, Tools,
) )
@ -60,7 +55,9 @@ from litellm.types.utils import (
ChatCompletionTokenLogprob, ChatCompletionTokenLogprob,
ChoiceLogprobs, ChoiceLogprobs,
GenericStreamingChunk, GenericStreamingChunk,
PromptTokensDetailsWrapper,
TopLogprob, TopLogprob,
Usage,
) )
from litellm.utils import CustomStreamWrapper, ModelResponse from litellm.utils import CustomStreamWrapper, ModelResponse
@ -679,7 +676,7 @@ class VertexGeminiConfig(BaseConfig):
model_response.choices = [choice] model_response.choices = [choice]
## GET USAGE ## ## GET USAGE ##
usage = litellm.Usage( usage = Usage(
prompt_tokens=completion_response["usageMetadata"].get( prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0 "promptTokenCount", 0
), ),
@ -717,7 +714,7 @@ class VertexGeminiConfig(BaseConfig):
model_response.choices = [choice] model_response.choices = [choice]
## GET USAGE ## ## GET USAGE ##
usage = litellm.Usage( usage = Usage(
prompt_tokens=completion_response["usageMetadata"].get( prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0 "promptTokenCount", 0
), ),
@ -731,6 +728,35 @@ class VertexGeminiConfig(BaseConfig):
return model_response return model_response
    def _calculate_usage(
        self,
        completion_response: GenerateContentResponseBody,
    ) -> Usage:
        cached_tokens: Optional[int] = None
        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if "cachedContentTokenCount" in completion_response["usageMetadata"]:
            cached_tokens = completion_response["usageMetadata"][
                "cachedContentTokenCount"
            ]

        if cached_tokens is not None:
            prompt_tokens_details = PromptTokensDetailsWrapper(
                cached_tokens=cached_tokens,
            )
        ## GET USAGE ##
        usage = Usage(
            prompt_tokens=completion_response["usageMetadata"].get(
                "promptTokenCount", 0
            ),
            completion_tokens=completion_response["usageMetadata"].get(
                "candidatesTokenCount", 0
            ),
            total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
            prompt_tokens_details=prompt_tokens_details,
        )

        return usage
def transform_response( def transform_response(
self, self,
model: str, model: str,
@ -854,19 +880,7 @@ class VertexGeminiConfig(BaseConfig):
model_response.choices.append(choice) model_response.choices.append(choice)
## GET USAGE ## usage = self._calculate_usage(completion_response=completion_response)
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get(
"totalTokenCount", 0
),
)
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
## ADD GROUNDING METADATA ## ## ADD GROUNDING METADATA ##
@ -943,126 +957,6 @@ class VertexGeminiConfig(BaseConfig):
return default_headers return default_headers
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
"""
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
The class `GoogleAIStudioGeminiConfig` provides configuration for the Google AI Studio's Gemini API interface. Below are the parameters:
- `temperature` (float): This controls the degree of randomness in token selection.
- `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. Other values - `application/json`.
- `response_schema` (dict): Optional. Output response schema of the generated candidate text when response mime type can have schema. Schema can be objects, primitives or arrays and is a subset of OpenAPI schema. If set, a compatible response_mime_type must also be set. Compatible mimetypes: application/json: Schema for JSON response.
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
response_mime_type: Optional[str] = None
response_schema: Optional[dict] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
response_mime_type: Optional[str] = None,
response_schema: Optional[dict] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"stream",
"tools",
"tool_choice",
"functions",
"response_format",
"n",
"stop",
"logprobs",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
return super().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=drop_params,
)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
for message in messages:
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if element.get("type") == "image_url":
img_element = element
_image_url: Optional[str] = None
if isinstance(img_element.get("image_url"), dict):
_image_url = img_element["image_url"].get("url") # type: ignore
else:
_image_url = img_element.get("image_url") # type: ignore
if _image_url and "https://" in _image_url:
image_obj = convert_to_anthropic_image_obj(_image_url)
img_element["image_url"] = ( # type: ignore
convert_generic_image_chunk_to_openai_image_obj(
image_obj
)
)
return _gemini_convert_messages_with_history(messages=messages)
async def make_call( async def make_call(
client: Optional[AsyncHTTPHandler], client: Optional[AsyncHTTPHandler],
api_base: str, api_base: str,

View file

@ -1944,6 +1944,8 @@
"max_output_tokens": 4096, "max_output_tokens": 4096,
"input_cost_per_token": 0.00000014, "input_cost_per_token": 0.00000014,
"input_cost_per_token_cache_hit": 0.000000014, "input_cost_per_token_cache_hit": 0.000000014,
"cache_read_input_token_cost": 0.000000014,
"cache_creation_input_token_cost": 0.0,
"output_cost_per_token": 0.00000028, "output_cost_per_token": 0.00000028,
"litellm_provider": "deepseek", "litellm_provider": "deepseek",
"mode": "chat", "mode": "chat",
@ -3758,6 +3760,8 @@
"max_audio_length_hours": 8.4, "max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1, "max_audio_per_prompt": 1,
"max_pdf_size_mb": 30, "max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075, "input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015, "input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003, "output_cost_per_token": 0.0000003,
@ -3783,6 +3787,8 @@
"max_audio_length_hours": 8.4, "max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1, "max_audio_per_prompt": 1,
"max_pdf_size_mb": 30, "max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075, "input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015, "input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003, "output_cost_per_token": 0.0000003,
@ -3842,6 +3848,7 @@
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_response_schema": true, "supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000, "tpm": 4000000,
"rpm": 2000, "rpm": 2000,
"source": "https://ai.google.dev/pricing" "source": "https://ai.google.dev/pricing"
@ -3866,6 +3873,7 @@
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_response_schema": true, "supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000, "tpm": 4000000,
"rpm": 4000, "rpm": 4000,
"source": "https://ai.google.dev/pricing" "source": "https://ai.google.dev/pricing"
@ -3890,6 +3898,7 @@
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_response_schema": true, "supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000, "tpm": 4000000,
"rpm": 4000, "rpm": 4000,
"source": "https://ai.google.dev/pricing" "source": "https://ai.google.dev/pricing"

View file

@ -183,6 +183,7 @@ class UsageMetadata(TypedDict, total=False):
promptTokenCount: int promptTokenCount: int
totalTokenCount: int totalTokenCount: int
candidatesTokenCount: int candidatesTokenCount: int
cachedContentTokenCount: int
class CachedContent(TypedDict, total=False): class CachedContent(TypedDict, total=False):

View file

@ -1944,6 +1944,8 @@
"max_output_tokens": 4096, "max_output_tokens": 4096,
"input_cost_per_token": 0.00000014, "input_cost_per_token": 0.00000014,
"input_cost_per_token_cache_hit": 0.000000014, "input_cost_per_token_cache_hit": 0.000000014,
"cache_read_input_token_cost": 0.000000014,
"cache_creation_input_token_cost": 0.0,
"output_cost_per_token": 0.00000028, "output_cost_per_token": 0.00000028,
"litellm_provider": "deepseek", "litellm_provider": "deepseek",
"mode": "chat", "mode": "chat",
@ -3758,6 +3760,8 @@
"max_audio_length_hours": 8.4, "max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1, "max_audio_per_prompt": 1,
"max_pdf_size_mb": 30, "max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075, "input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015, "input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003, "output_cost_per_token": 0.0000003,
@ -3783,6 +3787,8 @@
"max_audio_length_hours": 8.4, "max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1, "max_audio_per_prompt": 1,
"max_pdf_size_mb": 30, "max_pdf_size_mb": 30,
"cache_read_input_token_cost": 0.00000001875,
"cache_creation_input_token_cost": 0.000001,
"input_cost_per_token": 0.000000075, "input_cost_per_token": 0.000000075,
"input_cost_per_token_above_128k_tokens": 0.00000015, "input_cost_per_token_above_128k_tokens": 0.00000015,
"output_cost_per_token": 0.0000003, "output_cost_per_token": 0.0000003,
@ -3842,6 +3848,7 @@
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_response_schema": true, "supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000, "tpm": 4000000,
"rpm": 2000, "rpm": 2000,
"source": "https://ai.google.dev/pricing" "source": "https://ai.google.dev/pricing"
@ -3866,6 +3873,7 @@
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_response_schema": true, "supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000, "tpm": 4000000,
"rpm": 4000, "rpm": 4000,
"source": "https://ai.google.dev/pricing" "source": "https://ai.google.dev/pricing"
@ -3890,6 +3898,7 @@
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_response_schema": true, "supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000, "tpm": 4000000,
"rpm": 4000, "rpm": 4000,
"source": "https://ai.google.dev/pricing" "source": "https://ai.google.dev/pricing"

View file

@ -5,6 +5,7 @@ import sys
from typing import Any, Dict, List from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import os import os
import uuid
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -45,6 +46,7 @@ def _usage_format_tests(usage: litellm.Usage):
} }
``` ```
""" """
print(f"usage={usage}")
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens
@ -342,18 +344,17 @@ class BaseLLMChatTest(ABC):
print("Model does not support prompt caching") print("Model does not support prompt caching")
pytest.skip("Model does not support prompt caching") pytest.skip("Model does not support prompt caching")
try: uuid_str = str(uuid.uuid4())
for _ in range(2): messages = [
response = self.completion_function(
**base_completion_call_args,
messages=[
# System Message # System Message
{ {
"role": "system", "role": "system",
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": "Here is the full text of a complex legal agreement" "text": "Here is the full text of a complex legal agreement {}".format(
uuid_str
)
* 400, * 400,
"cache_control": {"type": "ephemeral"}, "cache_control": {"type": "ephemeral"},
} }
@ -385,10 +386,32 @@ class BaseLLMChatTest(ABC):
} }
], ],
}, },
], ]
try:
## call 1
response = self.completion_function(
**base_completion_call_args,
messages=messages,
max_tokens=10, max_tokens=10,
) )
initial_cost = response._hidden_params["response_cost"]
## call 2
response = self.completion_function(
**base_completion_call_args,
messages=messages,
max_tokens=10,
)
cached_cost = response._hidden_params["response_cost"]
assert (
cached_cost <= initial_cost
), "Cached cost={} should be less than initial cost={}".format(
cached_cost, initial_cost
)
_usage_format_tests(response.usage) _usage_format_tests(response.usage)
print("response=", response) print("response=", response)

View file

@ -1,9 +1,21 @@
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from base_llm_unit_tests import BaseLLMChatTest from base_llm_unit_tests import BaseLLMChatTest
from litellm.llms.vertex_ai.context_caching.transformation import (
separate_cached_messages,
)
class TestGoogleAIStudioGemini(BaseLLMChatTest): class TestGoogleAIStudioGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict: def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-1.5-flash"} return {"model": "gemini/gemini-1.5-flash-002"}
def test_tool_call_no_arguments(self, tool_call_no_arguments): def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
@ -13,3 +25,50 @@ class TestGoogleAIStudioGemini(BaseLLMChatTest):
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments) result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
print(result) print(result)
def test_gemini_context_caching_separate_messages():
    messages = [
        # System Message
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a complex legal agreement" * 400,
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What are the key terms and conditions in this agreement?",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {
            "role": "assistant",
            "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
        },
        # The final turn is marked with cache-control, for continuing in followups.
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What are the key terms and conditions in this agreement?",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
    ]
    cached_messages, non_cached_messages = separate_cached_messages(messages)
    print(cached_messages)
    print(non_cached_messages)

    assert len(cached_messages) > 0, "Cached messages should be present"
    assert len(non_cached_messages) > 0, "Non-cached messages should be present"

View file

@ -353,9 +353,9 @@ def test_all_model_configs():
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIConfig, VertexAIConfig,
GoogleAIStudioGeminiConfig,
VertexGeminiConfig, VertexGeminiConfig,
) )
from litellm.llms.gemini.chat.transformation import GoogleAIStudioGeminiConfig
assert "max_completion_tokens" in VertexAIConfig().get_supported_openai_params() assert "max_completion_tokens" in VertexAIConfig().get_supported_openai_params()

View file

@ -1163,14 +1163,19 @@ def test_completion_cost_azure_common_deployment_name():
assert "azure/gpt-4" == mock_client.call_args.kwargs["base_model"] assert "azure/gpt-4" == mock_client.call_args.kwargs["base_model"]
def test_completion_cost_anthropic_prompt_caching(): @pytest.mark.parametrize(
"model, custom_llm_provider",
[
("claude-3-5-sonnet-20240620", "anthropic"),
("gemini/gemini-1.5-flash-001", "gemini"),
],
)
def test_completion_cost_prompt_caching(model, custom_llm_provider):
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="") litellm.model_cost = litellm.get_model_cost_map(url="")
from litellm.utils import Choices, Message, ModelResponse, Usage from litellm.utils import Choices, Message, ModelResponse, Usage
model = "anthropic/claude-3-5-sonnet-20240620"
## WRITE TO CACHE ## (MORE EXPENSIVE) ## WRITE TO CACHE ## (MORE EXPENSIVE)
response_1 = ModelResponse( response_1 = ModelResponse(
id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424", id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424",
@ -1187,7 +1192,7 @@ def test_completion_cost_anthropic_prompt_caching():
) )
], ],
created=1725036547, created=1725036547,
model="claude-3-5-sonnet-20240620", model=model,
object="chat.completion", object="chat.completion",
system_fingerprint=None, system_fingerprint=None,
usage=Usage( usage=Usage(
@ -1203,7 +1208,7 @@ def test_completion_cost_anthropic_prompt_caching():
cost_1 = completion_cost(model=model, completion_response=response_1) cost_1 = completion_cost(model=model, completion_response=response_1)
_model_info = litellm.get_model_info( _model_info = litellm.get_model_info(
model="claude-3-5-sonnet-20240620", custom_llm_provider="anthropic" model=model, custom_llm_provider=custom_llm_provider
) )
expected_cost = ( expected_cost = (
( (
@ -1211,11 +1216,12 @@ def test_completion_cost_anthropic_prompt_caching():
- response_1.usage.prompt_tokens_details.cached_tokens - response_1.usage.prompt_tokens_details.cached_tokens
) )
* _model_info["input_cost_per_token"] * _model_info["input_cost_per_token"]
+ response_1.usage.prompt_tokens_details.cached_tokens + (response_1.usage.prompt_tokens_details.cached_tokens or 0)
* _model_info["cache_read_input_token_cost"] * _model_info["cache_read_input_token_cost"]
+ response_1.usage.cache_creation_input_tokens + (response_1.usage.cache_creation_input_tokens or 0)
* _model_info["cache_creation_input_token_cost"] * _model_info["cache_creation_input_token_cost"]
+ response_1.usage.completion_tokens * _model_info["output_cost_per_token"] + (response_1.usage.completion_tokens or 0)
* _model_info["output_cost_per_token"]
) # Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing) ) # Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
assert round(expected_cost, 5) == round(cost_1, 5) assert round(expected_cost, 5) == round(cost_1, 5)
@ -1238,7 +1244,7 @@ def test_completion_cost_anthropic_prompt_caching():
) )
], ],
created=1725036547, created=1725036547,
model="claude-3-5-sonnet-20240620", model=model,
object="chat.completion", object="chat.completion",
system_fingerprint=None, system_fingerprint=None,
usage=Usage( usage=Usage(
@ -2437,7 +2443,7 @@ def test_completion_cost_params_2():
def test_completion_cost_params_gemini_3(): def test_completion_cost_params_gemini_3():
from litellm.utils import Choices, Message, ModelResponse, Usage from litellm.utils import Choices, Message, ModelResponse, Usage
from litellm.litellm_core_utils.llm_cost_calc.google import cost_per_character from litellm.llms.vertex_ai.cost_calculator import cost_per_character
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="") litellm.model_cost = litellm.get_model_cost_map(url="")