(refactor) make get_cache_key a function under 100 LOC (#6327)

* refactor - use helpers for namespace and hashing

* use OpenAI to get the relevant supported params

* use helpers for getting cache key

* fix test caching

* use get/set helpers for preset cache keys

* make get_cache_key under 100 LOC

* fix _get_model_param_value

* fix _get_caching_group

* fix linting error

* add unit tests for get_cache_key

* test_generate_streaming_content

Ishaan Jaff, 2024-10-19 15:21:11 +05:30, committed by GitHub
parent 4cbdad9fc5
commit 979e8ea526
5 changed files with 477 additions and 124 deletions
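
For orientation, the refactor keeps get_cache_key small by delegating to helpers that the new unit tests below exercise directly (_get_relevant_args_to_use_for_cache_key, _get_model_param_value, _get_hashed_cache_key, _add_redis_namespace_to_cache_key, plus a preset-cache-key fast path). The following is only a rough sketch of that shape, written against the test expectations; the class name CacheSketch, the param list, and the exact key format are illustrative assumptions, not the committed implementation.

import hashlib
from typing import Optional


class CacheSketch:
    def __init__(self, namespace: Optional[str] = None):
        self.namespace = namespace

    def get_cache_key(self, **kwargs) -> str:
        # 1. honour a preset key if the caller already computed one
        preset_key = (kwargs.get("litellm_params") or {}).get("preset_cache_key")
        if preset_key is not None:
            return preset_key

        # 2. build "param: value" pairs for the params that affect the response
        key_parts = []
        for param in self._get_relevant_args_to_use_for_cache_key():
            if param not in kwargs:
                continue
            value = (
                self._get_model_param_value(kwargs)
                if param == "model"
                else str(kwargs[param])
            )
            key_parts.append(f"{param}: {value}")

        # 3. hash the joined pairs, then prefix the redis namespace if configured
        hashed = self._get_hashed_cache_key(",".join(key_parts))
        return self._add_redis_namespace_to_cache_key(hashed, **kwargs)

    def _get_relevant_args_to_use_for_cache_key(self) -> list:
        # illustrative subset; the real helper derives this from the
        # OpenAI-supported params plus litellm-specific ones
        return ["model", "messages", "prompt", "input", "temperature",
                "max_completion_tokens", "dimensions", "best_of", "seed"]

    def _get_model_param_value(self, kwargs: dict) -> str:
        metadata = kwargs.get("metadata") or {}
        model_group = metadata.get("model_group")
        for group in metadata.get("caching_groups") or []:
            if model_group in group:
                return str(group)  # any member maps to the whole group tuple
        return model_group or kwargs["model"]

    def _get_hashed_cache_key(self, cache_key: str) -> str:
        return hashlib.sha256(cache_key.encode()).hexdigest()

    def _add_redis_namespace_to_cache_key(self, hashed_key: str, **kwargs) -> str:
        namespace = (kwargs.get("metadata") or {}).get("redis_namespace") or self.namespace
        return f"{namespace}:{hashed_key}" if namespace else hashed_key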


@@ -974,7 +974,7 @@ async def test_redis_cache_acompletion_stream():
         response_1_content += chunk.choices[0].delta.content or ""
     print(response_1_content)
-    time.sleep(0.5)
+    await asyncio.sleep(0.5)
     print("\n\n Response 1 content: ", response_1_content, "\n\n")
     response2 = await litellm.acompletion(
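
The one-line change above swaps a blocking time.sleep for await asyncio.sleep inside an async test: time.sleep freezes the whole event loop, so no other coroutine can make progress during the pause, while awaiting asyncio.sleep yields control back to the loop. A minimal standalone illustration of the difference (not code from the repo):

import asyncio


async def ticker():
    # a background coroutine that should keep running during the pause
    for _ in range(3):
        print("tick")
        await asyncio.sleep(0.1)


async def main():
    task = asyncio.create_task(ticker())

    # time.sleep(0.5) here would block the event loop: no "tick" would print
    # until the sleep returned. await asyncio.sleep(0.5) yields instead,
    # so the ticker keeps running while we wait.
    await asyncio.sleep(0.5)

    await task


asyncio.run(main())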


@@ -0,0 +1,245 @@
import os
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape
load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import hashlib
import random
import pytest
import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache
from unittest.mock import AsyncMock, patch, MagicMock
from litellm.caching.caching_handler import LLMCachingHandler, CachingHandlerResponse
from litellm.caching.caching import LiteLLMCacheType
from litellm.types.utils import CallTypes
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    ModelResponse,
    EmbeddingResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Embedding,
)
from datetime import timedelta, datetime
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm._logging import verbose_logger
import logging


def test_get_kwargs_for_cache_key():
    _cache = litellm.Cache()
    relevant_kwargs = _cache._get_relevant_args_to_use_for_cache_key()
    print(relevant_kwargs)


def test_get_cache_key_chat_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    # different relevant params (max_completion_tokens instead of temperature) -> different key
    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    # identical kwargs -> identical key
    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_cache_key_embedding():
    cache = Cache()
    kwargs = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1536,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_cache_key_text_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 3,
        "logit_bias": {"123": 1},
        "seed": 42,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_hashed_cache_key():
    cache = Cache()
    cache_key = "model:gpt-3.5-turbo,messages:Hello world"
    hashed_key = cache._get_hashed_cache_key(cache_key)
    assert len(hashed_key) == 64  # SHA-256 produces a 64-character hex string


def test_add_redis_namespace_to_cache_key():
    cache = Cache(namespace="test_namespace")
    hashed_key = "abcdef1234567890"

    # Test with class-level namespace
    result = cache._add_redis_namespace_to_cache_key(hashed_key)
    assert result == "test_namespace:abcdef1234567890"

    # Test with metadata namespace
    kwargs = {"metadata": {"redis_namespace": "custom_namespace"}}
    result = cache._add_redis_namespace_to_cache_key(hashed_key, **kwargs)
    assert result == "custom_namespace:abcdef1234567890"


def test_get_model_param_value():
    cache = Cache()

    # Test with regular model
    kwargs = {"model": "gpt-3.5-turbo"}
    assert cache._get_model_param_value(kwargs) == "gpt-3.5-turbo"

    # Test with model_group
    kwargs = {"model": "gpt-3.5-turbo", "metadata": {"model_group": "gpt-group"}}
    assert cache._get_model_param_value(kwargs) == "gpt-group"

    # Test with caching_group
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "openai-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    # Any member of a caching group should map to the same group tuple
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "azure-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    # A model_group outside every caching group falls back to the model_group name
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "not-in-caching-group-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert cache._get_model_param_value(kwargs) == "not-in-caching-group-gpt-3.5-turbo"


def test_preset_cache_key():
    """
    Test that the preset cache key is used if it is set in kwargs["litellm_params"]
    """
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
        "litellm_params": {"preset_cache_key": "preset-cache-key"},
    }
    assert cache.get_cache_key(**kwargs) == "preset-cache-key"


def test_generate_streaming_content():
    cache = Cache()
    content = "Hello, this is a test message."
    generator = cache.generate_streaming_content(content)

    full_response = ""
    chunk_count = 0

    for chunk in generator:
        chunk_count += 1
        assert "choices" in chunk
        assert len(chunk["choices"]) == 1
        assert "delta" in chunk["choices"][0]
        assert "role" in chunk["choices"][0]["delta"]
        assert chunk["choices"][0]["delta"]["role"] == "assistant"
        assert "content" in chunk["choices"][0]["delta"]

        chunk_content = chunk["choices"][0]["delta"]["content"]
        full_response += chunk_content

        # Check that each chunk is no longer than 5 characters
        assert len(chunk_content) <= 5

    print("full_response from generate_streaming_content", full_response)

    # Check that the full content is reconstructed correctly
    assert full_response == content
    # Check that there were multiple chunks
    assert chunk_count > 1
    print(f"Number of chunks: {chunk_count}")
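
For readers of the last test: its assertions pin down the contract of Cache.generate_streaming_content — it yields OpenAI-style delta chunks with role "assistant", each carrying at most 5 characters of the cached content, and the concatenated deltas reproduce the original string. A generator meeting that contract could look like the sketch below; this is inferred from the test, not the library's implementation, and the chunk_size parameter is an assumption.

def generate_streaming_content(content: str, chunk_size: int = 5):
    # yield the cached content back as OpenAI-style streaming delta chunks
    for i in range(0, len(content), chunk_size):
        yield {
            "choices": [
                {
                    "delta": {
                        "role": "assistant",
                        "content": content[i : i + chunk_size],
                    }
                }
            ]
        }


# reassembling the chunks reproduces the original content
assert "".join(
    chunk["choices"][0]["delta"]["content"]
    for chunk in generate_streaming_content("Hello, this is a test message.")
) == "Hello, this is a test message."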