# litellm/tests/local_testing/test_unit_test_caching.py

import os
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv
from test_rerank import assert_response_shape

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import asyncio
import hashlib
import random

import pytest

import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache
from unittest.mock import AsyncMock, patch, MagicMock
from litellm.caching.caching_handler import LLMCachingHandler, CachingHandlerResponse
from litellm.caching.caching import LiteLLMCacheType
from litellm.types.utils import CallTypes
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    ModelResponse,
    EmbeddingResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Embedding,
)
from datetime import timedelta, datetime
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm._logging import verbose_logger
import logging


def test_get_kwargs_for_cache_key():
    _cache = litellm.Cache()
    relevant_kwargs = _cache._get_relevant_args_to_use_for_cache_key()
    print(relevant_kwargs)


def test_get_cache_key_chat_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)

    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)

    assert cache_key_2 == cache_key_3
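

# Hedged sketch (an addition, not part of the original suite): the assertions
# above imply that Cache.get_cache_key is deterministic - identical kwargs
# always map to the same key. A minimal property-style check of that assumption:
def test_get_cache_key_is_deterministic_sketch():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
    }
    assert cache.get_cache_key(**kwargs) == cache.get_cache_key(**kwargs)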


def test_get_cache_key_embedding():
    cache = Cache()
    kwargs = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1536,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)

    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)

    assert cache_key_2 == cache_key_3


def test_get_cache_key_text_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 3,
        "logit_bias": {"123": 1},
        "seed": 42,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)

    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)

    assert cache_key_2 == cache_key_3


def test_get_hashed_cache_key():
    cache = Cache()
    cache_key = "model:gpt-3.5-turbo,messages:Hello world"
    hashed_key = Cache._get_hashed_cache_key(cache_key)
    assert len(hashed_key) == 64  # SHA-256 produces a 64-character hex string
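

# Illustrative sketch (an addition, not original test code): the 64-character
# assertion above holds because a SHA-256 hex digest is always 64 characters,
# regardless of the length of the input string.
def test_sha256_hexdigest_length_sketch():
    for raw_key in ["model:gpt-3.5-turbo,messages:Hello world", "x" * 10_000]:
        assert len(hashlib.sha256(raw_key.encode()).hexdigest()) == 64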


def test_add_redis_namespace_to_cache_key():
    cache = Cache(namespace="test_namespace")
    hashed_key = "abcdef1234567890"

    # Test with class-level namespace
    result = cache._add_redis_namespace_to_cache_key(hashed_key)
    assert result == "test_namespace:abcdef1234567890"

    # Test with metadata namespace
    kwargs = {"metadata": {"redis_namespace": "custom_namespace"}}
    result = cache._add_redis_namespace_to_cache_key(hashed_key, **kwargs)
    assert result == "custom_namespace:abcdef1234567890"


def test_get_model_param_value():
    cache = Cache()

    # Test with regular model
    kwargs = {"model": "gpt-3.5-turbo"}
    assert cache._get_model_param_value(kwargs) == "gpt-3.5-turbo"

    # Test with model_group
    kwargs = {"model": "gpt-3.5-turbo", "metadata": {"model_group": "gpt-group"}}
    assert cache._get_model_param_value(kwargs) == "gpt-group"

    # Test with caching_group - the first member of the group resolves to the group tuple
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "openai-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    # The second member of the same caching group resolves to the same group tuple
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "azure-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    # A model_group outside every caching group falls back to the model_group name
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "not-in-caching-group-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert cache._get_model_param_value(kwargs) == "not-in-caching-group-gpt-3.5-turbo"


def test_preset_cache_key():
    """
    Test that the preset cache key is used if it is set in kwargs["litellm_params"]
    """
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
        "litellm_params": {"preset_cache_key": "preset-cache-key"},
    }

    assert cache.get_cache_key(**kwargs) == "preset-cache-key"
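

# Hedged sketch of the precedence test_preset_cache_key documents (an
# assumption about the flow, not litellm's actual code): a preset key in
# litellm_params wins over a freshly computed cache key.
def _preset_or_computed_key_sketch(computed_key: str, **kwargs) -> str:
    preset = (kwargs.get("litellm_params") or {}).get("preset_cache_key")
    return preset if preset is not None else computed_key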


def test_generate_streaming_content():
    cache = Cache()
    content = "Hello, this is a test message."
    generator = cache.generate_streaming_content(content)

    full_response = ""
    chunk_count = 0

    for chunk in generator:
        chunk_count += 1
        assert "choices" in chunk
        assert len(chunk["choices"]) == 1
        assert "delta" in chunk["choices"][0]
        assert "role" in chunk["choices"][0]["delta"]
        assert chunk["choices"][0]["delta"]["role"] == "assistant"
        assert "content" in chunk["choices"][0]["delta"]

        chunk_content = chunk["choices"][0]["delta"]["content"]
        full_response += chunk_content

        # Check that each chunk is no longer than 5 characters
        assert len(chunk_content) <= 5

    print("full_response from generate_streaming_content", full_response)

    # Check that the full content is reconstructed correctly
    assert full_response == content
    # Check that there were multiple chunks
    assert chunk_count > 1

    print(f"Number of chunks: {chunk_count}")