(refactor) get_cache_key to be under 100 LOC function (#6327)

* refactor - use helpers for namespace and hashing
* use openai to get the relevant supported params
* use helpers for getting cache key
* fix test caching
* use get/set helpers for preset cache keys
* make get_cache_key under 100 LOC
* fix _get_model_param_value
* fix _get_caching_group
* fix linting error
* add unit testing for get cache key
* test_generate_streaming_content
Parent: 4cbdad9fc5
Commit: 979e8ea526
5 changed files with 477 additions and 124 deletions
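For orientation before the diff: a minimal usage sketch (not part of this commit) of how the refactored Cache.get_cache_key is expected to behave after this change. The model and message values are illustrative, and the 64-character length assumes no cache namespace is configured.

from litellm.caching.caching import Cache

cache = Cache()  # default in-memory cache

kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hello, world!"}],
    "temperature": 0.7,
}

# Identical call arguments hash to the same SHA-256 key ...
key_1 = cache.get_cache_key(**kwargs)
key_2 = cache.get_cache_key(**kwargs)
assert key_1 == key_2 and len(key_1) == 64  # 64 hex chars when no namespace is set

# ... while changing a supported OpenAI param (here: temperature) changes the key.
key_3 = cache.get_cache_key(**{**kwargs, "temperature": 0.2})
assert key_3 != key_1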
@@ -17,13 +17,26 @@ import logging
 import time
 import traceback
 from enum import Enum
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
+
+from openai.types.audio.transcription_create_params import TranscriptionCreateParams
+from openai.types.chat.completion_create_params import (
+    CompletionCreateParamsNonStreaming,
+    CompletionCreateParamsStreaming,
+)
+from openai.types.completion_create_params import (
+    CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming,
+)
+from openai.types.completion_create_params import (
+    CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming,
+)
+from openai.types.embedding_create_params import EmbeddingCreateParams
 from pydantic import BaseModel

 import litellm
 from litellm._logging import verbose_logger
 from litellm.types.caching import *
+from litellm.types.rerank import RerankRequest
 from litellm.types.utils import all_litellm_params

 from .base_cache import BaseCache
@@ -220,7 +233,7 @@ class Cache:
         if self.namespace is not None and isinstance(self.cache, RedisCache):
             self.cache.namespace = self.namespace

-    def get_cache_key(self, *args, **kwargs) -> str:  # noqa: PLR0915
+    def get_cache_key(self, *args, **kwargs) -> str:
         """
         Get the cache key for the given arguments.

@@ -232,106 +245,19 @@ class Cache:
             str: The cache key generated from the arguments, or None if no cache key could be generated.
         """
         cache_key = ""
-        print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
+        verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

-        # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
-        if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
-            _preset_cache_key = kwargs.get("litellm_params", {}).get(
-                "preset_cache_key", None
-            )
-            print_verbose(f"\nReturning preset cache key: {_preset_cache_key}")
-            return _preset_cache_key
+        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
+        if preset_cache_key is not None:
+            verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
+            return preset_cache_key

-        # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
-        completion_kwargs = [
-            "model",
-            "messages",
-            "prompt",
-            "temperature",
-            "top_p",
-            "n",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "stream",
-        ]
-        embedding_only_kwargs = [
-            "input",
-            "encoding_format",
-        ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
-        transcription_only_kwargs = [
-            "file",
-            "language",
-        ]
-        rerank_only_kwargs = [
-            "top_n",
-            "rank_fields",
-            "return_documents",
-            "max_chunks_per_doc",
-            "documents",
-            "query",
-        ]
-        # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = (
-            completion_kwargs
-            + embedding_only_kwargs
-            + transcription_only_kwargs
-            + rerank_only_kwargs
-        )
+        combined_kwargs = self._get_relevant_args_to_use_for_cache_key()
         litellm_param_kwargs = all_litellm_params
         for param in kwargs:
             if param in combined_kwargs:
-                # check if param == model and model_group is passed in, then override model with model_group
-                if param == "model":
-                    model_group = None
-                    caching_group = None
-                    metadata = kwargs.get("metadata", None)
-                    litellm_params = kwargs.get("litellm_params", {})
-                    if metadata is not None:
-                        model_group = metadata.get("model_group")
-                        model_group = metadata.get("model_group", None)
-                        caching_groups = metadata.get("caching_groups", None)
-                        if caching_groups:
-                            for group in caching_groups:
-                                if model_group in group:
-                                    caching_group = group
-                                    break
-                    if litellm_params is not None:
-                        metadata = litellm_params.get("metadata", None)
-                        if metadata is not None:
-                            model_group = metadata.get("model_group", None)
-                            caching_groups = metadata.get("caching_groups", None)
-                            if caching_groups:
-                                for group in caching_groups:
-                                    if model_group in group:
-                                        caching_group = group
-                                        break
-                    param_value = (
-                        caching_group or model_group or kwargs[param]
-                    )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
-                elif param == "file":
-                    file = kwargs.get("file")
-                    metadata = kwargs.get("metadata", {})
-                    litellm_params = kwargs.get("litellm_params", {})
-
-                    # get checksum of file content
-                    param_value = (
-                        metadata.get("file_checksum")
-                        or getattr(file, "name", None)
-                        or metadata.get("file_name")
-                        or litellm_params.get("file_name")
-                    )
-                else:
-                    if kwargs[param] is None:
-                        continue  # ignore None params
-                    param_value = kwargs[param]
-                cache_key += f"{str(param)}: {str(param_value)}"
+                param_value: Optional[str] = self._get_param_value(param, kwargs)
+                if param_value is not None:
+                    cache_key += f"{str(param)}: {str(param_value)}"
             elif (
                 param not in litellm_param_kwargs
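Before the next hunk adds the helper methods, here is an illustrative sketch (not from the diff) of what the loop above produces: each relevant kwarg is appended to the raw key as "param: value", and the result is then SHA-256 hashed by the _get_hashed_cache_key helper introduced below. The param values are illustrative.

import hashlib

raw_key = ""
for param, value in {"model": "gpt-3.5-turbo", "temperature": 0.7}.items():
    # mirrors: cache_key += f"{str(param)}: {str(param_value)}"
    raw_key += f"{str(param)}: {str(value)}"

hashed = hashlib.sha256(raw_key.encode()).hexdigest()
assert len(hashed) == 64  # hex digest length, as asserted in the new unit tests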
@@ -344,19 +270,200 @@ class Cache:
                     param_value = kwargs[param]
                     cache_key += f"{str(param)}: {str(param_value)}"

-        print_verbose(f"\nCreated cache key: {cache_key}")
-        # Use hashlib to create a sha256 hash of the cache key
+        verbose_logger.debug("\nCreated cache key: %s", cache_key)
+        hashed_cache_key = self._get_hashed_cache_key(cache_key)
+        hashed_cache_key = self._add_redis_namespace_to_cache_key(
+            hashed_cache_key, **kwargs
+        )
+        self._set_preset_cache_key_in_kwargs(
+            preset_cache_key=hashed_cache_key, **kwargs
+        )
+        return hashed_cache_key
+
+    def _get_param_value(
+        self,
+        param: str,
+        kwargs: dict,
+    ) -> Optional[str]:
+        """
+        Get the value for the given param from kwargs
+        """
+        if param == "model":
+            return self._get_model_param_value(kwargs)
+        elif param == "file":
+            return self._get_file_param_value(kwargs)
+        return kwargs[param]
+
+    def _get_model_param_value(self, kwargs: dict) -> str:
+        """
+        Handles getting the value for the 'model' param from kwargs
+
+        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
+        2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
+        3. Else use the `model` passed in kwargs
+        """
+        metadata: Dict = kwargs.get("metadata", {}) or {}
+        litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
+        metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
+        model_group: Optional[str] = metadata.get(
+            "model_group"
+        ) or metadata_in_litellm_params.get("model_group")
+        caching_group = self._get_caching_group(metadata, model_group)
+        return caching_group or model_group or kwargs["model"]
+
+    def _get_caching_group(
+        self, metadata: dict, model_group: Optional[str]
+    ) -> Optional[str]:
+        caching_groups: Optional[List] = metadata.get("caching_groups", [])
+        if caching_groups:
+            for group in caching_groups:
+                if model_group in group:
+                    return str(group)
+        return None
+
+    def _get_file_param_value(self, kwargs: dict) -> str:
+        """
+        Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
+        """
+        file = kwargs.get("file")
+        metadata = kwargs.get("metadata", {})
+        litellm_params = kwargs.get("litellm_params", {})
+        return (
+            metadata.get("file_checksum")
+            or getattr(file, "name", None)
+            or metadata.get("file_name")
+            or litellm_params.get("file_name")
+        )
+
+    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
+        """
+        Get the preset cache key from kwargs["litellm_params"]
+
+        We use _get_preset_cache_keys for two reasons
+
+        1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
+        2. avoid doing duplicate / repeated work
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                return kwargs["litellm_params"].get("preset_cache_key", None)
+        return None
+
+    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
+        """
+        Set the calculated cache key in kwargs
+
+        This is used to avoid doing duplicate / repeated work
+
+        Placed in kwargs["litellm_params"]
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
+
+    def _get_relevant_args_to_use_for_cache_key(self) -> Set[str]:
+        """
+        Gets the supported kwargs for each call type and combines them
+        """
+        chat_completion_kwargs = self._get_litellm_supported_chat_completion_kwargs()
+        text_completion_kwargs = self._get_litellm_supported_text_completion_kwargs()
+        embedding_kwargs = self._get_litellm_supported_embedding_kwargs()
+        transcription_kwargs = self._get_litellm_supported_transcription_kwargs()
+        rerank_kwargs = self._get_litellm_supported_rerank_kwargs()
+        exclude_kwargs = self._get_kwargs_to_exclude_from_cache_key()
+
+        combined_kwargs = chat_completion_kwargs.union(
+            text_completion_kwargs,
+            embedding_kwargs,
+            transcription_kwargs,
+            rerank_kwargs,
+        )
+        combined_kwargs = combined_kwargs.difference(exclude_kwargs)
+        return combined_kwargs
+
+    def _get_litellm_supported_chat_completion_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported chat completion kwargs
+
+        This follows the OpenAI API Spec
+        """
+        all_chat_completion_kwargs = set(
+            CompletionCreateParamsNonStreaming.__annotations__.keys()
+        ).union(set(CompletionCreateParamsStreaming.__annotations__.keys()))
+        return all_chat_completion_kwargs
+
+    def _get_litellm_supported_text_completion_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported text completion kwargs
+
+        This follows the OpenAI API Spec
+        """
+        all_text_completion_kwargs = set(
+            TextCompletionCreateParamsNonStreaming.__annotations__.keys()
+        ).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys()))
+        return all_text_completion_kwargs
+
+    def _get_litellm_supported_rerank_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported rerank kwargs
+        """
+        return set(RerankRequest.model_fields.keys())
+
+    def _get_litellm_supported_embedding_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported embedding kwargs
+
+        This follows the OpenAI API Spec
+        """
+        return set(EmbeddingCreateParams.__annotations__.keys())
+
+    def _get_litellm_supported_transcription_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported transcription kwargs
+
+        This follows the OpenAI API Spec
+        """
+        return set(TranscriptionCreateParams.__annotations__.keys())
+
+    def _get_kwargs_to_exclude_from_cache_key(self) -> Set[str]:
+        """
+        Get the kwargs to exclude from the cache key
+        """
+        return set(["metadata"])
+
+    def _get_hashed_cache_key(self, cache_key: str) -> str:
+        """
+        Get the hashed cache key for the given cache key.
+
+        Use hashlib to create a sha256 hash of the cache key
+
+        Args:
+            cache_key (str): The cache key to hash.
+
+        Returns:
+            str: The hashed cache key.
+        """
+        hash_object = hashlib.sha256(cache_key.encode())
+        # Hexadecimal representation of the hash
+        hash_hex = hash_object.hexdigest()
-        print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
-        if kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
-            _namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
-            hash_hex = f"{_namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
-        elif self.namespace is not None:
-            hash_hex = f"{self.namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
+        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
+        return hash_hex
+
+    def _add_redis_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
+        """
+        If a redis namespace is provided, add it to the cache key
+
+        Args:
+            hash_hex (str): The hashed cache key.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The final hashed cache key with the redis namespace.
+        """
+        namespace = kwargs.get("metadata", {}).get("redis_namespace") or self.namespace
+        if namespace:
+            hash_hex = f"{namespace}:{hash_hex}"
+        verbose_logger.debug("Final hashed key: %s", hash_hex)
+        return hash_hex

     def generate_streaming_content(self, content):
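The param-discovery helpers above rely on the fact that the OpenAI request types are TypedDicts, so their __annotations__ expose the supported parameter names. A standalone sketch of that technique, using the same imports this diff adds at the top of the file:

from openai.types.chat.completion_create_params import (
    CompletionCreateParamsNonStreaming,
    CompletionCreateParamsStreaming,
)

# Union of the non-streaming and streaming request params, as in
# _get_litellm_supported_chat_completion_kwargs above.
chat_completion_kwargs = set(
    CompletionCreateParamsNonStreaming.__annotations__.keys()
).union(set(CompletionCreateParamsStreaming.__annotations__.keys()))

# Expected to contain the OpenAI chat params, e.g. "model", "messages",
# "temperature", "tools", plus "stream" from the streaming variant.
print(sorted(chat_completion_kwargs))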
@@ -182,7 +182,9 @@ class LLMCachingHandler:
                 end_time=end_time,
                 cache_hit=cache_hit,
             )
-            cache_key = kwargs.get("preset_cache_key", None)
+            cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+                **kwargs
+            )
             if (
                 isinstance(cached_result, BaseModel)
                 or isinstance(cached_result, CustomStreamWrapper)
@@ -236,12 +238,7 @@ class LLMCachingHandler:
             original_function=original_function
         ):
             print_verbose("Checking Cache")
-            preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-            kwargs["preset_cache_key"] = (
-                preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-            )
             cached_result = litellm.cache.get_cache(*args, **kwargs)

             if cached_result is not None:
                 if "detail" in cached_result:
                     # implies an error occurred
@@ -285,7 +282,9 @@ class LLMCachingHandler:
                 target=logging_obj.success_handler,
                 args=(cached_result, start_time, end_time, cache_hit),
             ).start()
-            cache_key = kwargs.get("preset_cache_key", None)
+            cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+                **kwargs
+            )
             if (
                 isinstance(cached_result, BaseModel)
                 or isinstance(cached_result, CustomStreamWrapper)
@@ -493,10 +492,6 @@ class LLMCachingHandler:
                 if all(result is None for result in cached_result):
                     cached_result = None
             else:
-                preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                kwargs["preset_cache_key"] = (
-                    preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                )
                 if litellm.cache._supports_async() is True:
                     cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
                 else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
@@ -842,10 +837,16 @@ class LLMCachingHandler:
             "metadata": kwargs.get("metadata", {}),
             "model_info": kwargs.get("model_info", {}),
             "proxy_server_request": kwargs.get("proxy_server_request", None),
-            "preset_cache_key": kwargs.get("preset_cache_key", None),
             "stream_response": kwargs.get("stream_response", {}),
         }
+
+        if litellm.cache is not None:
+            litellm_params["preset_cache_key"] = (
+                litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
+            )
+        else:
+            litellm_params["preset_cache_key"] = None

         logging_obj.update_environment_variables(
             model=model,
             user=kwargs.get("user", None),
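The handler changes above all swap direct reads of kwargs["preset_cache_key"] for the new get/set helpers on the Cache object. A minimal round-trip sketch (not from the diff) of those helpers; the key value is illustrative:

from litellm.caching.caching import Cache

cache = Cache()
kwargs = {"litellm_params": {}}  # the preset key lives under kwargs["litellm_params"]

# The setter mutates the nested litellm_params dict in place, so the caller's
# kwargs see the stored key even though **kwargs unpacks into a new dict.
cache._set_preset_cache_key_in_kwargs(preset_cache_key="abc123", **kwargs)
assert cache._get_preset_cache_key_from_kwargs(**kwargs) == "abc123"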
@@ -664,10 +664,10 @@ class LangFuseLogger:
             if "cache_key" in litellm.langfuse_default_tags:
                 _hidden_params = metadata.get("hidden_params", {}) or {}
                 _cache_key = _hidden_params.get("cache_key", None)
-                if _cache_key is None:
+                if _cache_key is None and litellm.cache is not None:
                     # fallback to using "preset_cache_key"
-                    _preset_cache_key = kwargs.get("litellm_params", {}).get(
-                        "preset_cache_key", None
+                    _preset_cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+                        **kwargs
                     )
                     _cache_key = _preset_cache_key
                 tags.append(f"cache_key:{_cache_key}")
@@ -974,7 +974,7 @@ async def test_redis_cache_acompletion_stream():
         response_1_content += chunk.choices[0].delta.content or ""
     print(response_1_content)

-    time.sleep(0.5)
+    await asyncio.sleep(0.5)
     print("\n\n Response 1 content: ", response_1_content, "\n\n")

     response2 = await litellm.acompletion(
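The one-line change above swaps a blocking sleep for an awaited one. A short illustration (sketch, not from the diff) of why that matters inside an async test:

import asyncio


async def pause_without_blocking() -> None:
    # time.sleep(0.5) would block the event loop and stall any in-flight streaming
    # coroutines; awaiting asyncio.sleep(0.5) yields control back to the loop instead.
    await asyncio.sleep(0.5)


asyncio.run(pause_without_blocking())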
tests/local_testing/test_unit_test_caching.py (new file, 245 lines)
@@ -0,0 +1,245 @@
import os
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv
from test_rerank import assert_response_shape


load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import hashlib
import random

import pytest

import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache

from unittest.mock import AsyncMock, patch, MagicMock
from litellm.caching.caching_handler import LLMCachingHandler, CachingHandlerResponse
from litellm.caching.caching import LiteLLMCacheType
from litellm.types.utils import CallTypes
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    ModelResponse,
    EmbeddingResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Embedding,
)
from datetime import timedelta, datetime
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm._logging import verbose_logger
import logging


def test_get_kwargs_for_cache_key():
    _cache = litellm.Cache()
    relevant_kwargs = _cache._get_relevant_args_to_use_for_cache_key()
    print(relevant_kwargs)


def test_get_cache_key_chat_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_cache_key_embedding():
    cache = Cache()
    kwargs = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1536,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_cache_key_text_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 3,
        "logit_bias": {"123": 1},
        "seed": 42,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_hashed_cache_key():
    cache = Cache()
    cache_key = "model:gpt-3.5-turbo,messages:Hello world"
    hashed_key = cache._get_hashed_cache_key(cache_key)
    assert len(hashed_key) == 64  # SHA-256 produces a 64-character hex string


def test_add_redis_namespace_to_cache_key():
    cache = Cache(namespace="test_namespace")
    hashed_key = "abcdef1234567890"

    # Test with class-level namespace
    result = cache._add_redis_namespace_to_cache_key(hashed_key)
    assert result == "test_namespace:abcdef1234567890"

    # Test with metadata namespace
    kwargs = {"metadata": {"redis_namespace": "custom_namespace"}}
    result = cache._add_redis_namespace_to_cache_key(hashed_key, **kwargs)
    assert result == "custom_namespace:abcdef1234567890"


def test_get_model_param_value():
    cache = Cache()

    # Test with regular model
    kwargs = {"model": "gpt-3.5-turbo"}
    assert cache._get_model_param_value(kwargs) == "gpt-3.5-turbo"

    # Test with model_group
    kwargs = {"model": "gpt-3.5-turbo", "metadata": {"model_group": "gpt-group"}}
    assert cache._get_model_param_value(kwargs) == "gpt-group"

    # Test with caching_group
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "openai-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "azure-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "not-in-caching-group-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert cache._get_model_param_value(kwargs) == "not-in-caching-group-gpt-3.5-turbo"


def test_preset_cache_key():
    """
    Test that the preset cache key is used if it is set in kwargs["litellm_params"]
    """
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
        "litellm_params": {"preset_cache_key": "preset-cache-key"},
    }

    assert cache.get_cache_key(**kwargs) == "preset-cache-key"


def test_generate_streaming_content():
    cache = Cache()
    content = "Hello, this is a test message."
    generator = cache.generate_streaming_content(content)

    full_response = ""
    chunk_count = 0

    for chunk in generator:
        chunk_count += 1
        assert "choices" in chunk
        assert len(chunk["choices"]) == 1
        assert "delta" in chunk["choices"][0]
        assert "role" in chunk["choices"][0]["delta"]
        assert chunk["choices"][0]["delta"]["role"] == "assistant"
        assert "content" in chunk["choices"][0]["delta"]

        chunk_content = chunk["choices"][0]["delta"]["content"]
        full_response += chunk_content

        # Check that each chunk is no longer than 5 characters
        assert len(chunk_content) <= 5
    print("full_response from generate_streaming_content", full_response)
    # Check that the full content is reconstructed correctly
    assert full_response == content
    # Check that there were multiple chunks
    assert chunk_count > 1

    print(f"Number of chunks: {chunk_count}")
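One scenario the tests above exercise only at the helper level: when a namespace is configured on the Cache, the final key returned by get_cache_key carries that prefix. A complementary end-to-end sketch (not part of the diff), reusing the namespace value from test_add_redis_namespace_to_cache_key:

from litellm.caching.caching import Cache

namespaced_cache = Cache(namespace="test_namespace")
key = namespaced_cache.get_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, world!"}],
)
# Expected shape: "<namespace>:<64-char sha256 hex>", per _add_redis_namespace_to_cache_key.
assert key.startswith("test_namespace:")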