forked from phoenix/litellm-mirror
refactor: add black formatting
parent b87d630b0a
commit 4905929de3
156 changed files with 19723 additions and 10869 deletions
@@ -12,13 +12,15 @@ import time, logging
 import json, traceback, ast
 from typing import Optional, Literal, List

+
 def print_verbose(print_statement):
     try:
         if litellm.set_verbose:
-            print(print_statement) # noqa
+            print(print_statement)  # noqa
     except:
         pass

+
 class BaseCache:
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError
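For context: print_verbose only emits output when litellm's verbose flag is switched on. A minimal usage sketch (standard litellm flag, not introduced by this commit):

    import litellm

    litellm.set_verbose = True   # cache debug output from print_verbose() is printed
    litellm.set_verbose = False  # back to silent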
@@ -45,13 +47,13 @@ class InMemoryCache(BaseCache):
                     self.cache_dict.pop(key, None)
                     return None
             original_cached_response = self.cache_dict[key]
-            try: 
+            try:
                 cached_response = json.loads(original_cached_response)
-            except: 
+            except:
                 cached_response = original_cached_response
             return cached_response
         return None

     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
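InMemoryCache keeps serialized values in cache_dict and expiry timestamps in ttl_dict, popping expired keys on read. A small sketch, assuming the ttl keyword is passed through **kwargs as the ttl_dict lookups above suggest:

    from litellm.caching import InMemoryCache

    cache = InMemoryCache()
    cache.set_cache("key-1", '{"answer": 42}', ttl=60)  # assumed ttl kwarg: expire in 60s
    print(cache.get_cache("key-1"))  # JSON string is decoded back to a dict on read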
@@ -60,17 +62,18 @@ class InMemoryCache(BaseCache):
 class RedisCache(BaseCache):
     def __init__(self, host=None, port=None, password=None, **kwargs):
         import redis

         # if users don't provider one, use the default litellm cache
         from ._redis import get_redis_client

         redis_kwargs = {}
-        if host is not None: 
+        if host is not None:
             redis_kwargs["host"] = host
         if port is not None:
             redis_kwargs["port"] = port
-        if password is not None: 
+        if password is not None:
             redis_kwargs["password"] = password
+
         redis_kwargs.update(kwargs)

         self.redis_client = get_redis_client(**redis_kwargs)
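RedisCache only adds host/port/password to redis_kwargs when they are given, then forwards everything (plus any extra **kwargs) to get_redis_client. A hedged construction sketch; the connection values are placeholders and usually come from environment variables or the Cache wrapper further down:

    from litellm.caching import RedisCache

    redis_cache = RedisCache(host="localhost", port=6379)  # password omitted, so not sent
    redis_cache.set_cache("greeting", '{"text": "hello"}')
    print(redis_cache.get_cache("greeting"))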
@@ -88,13 +91,19 @@ class RedisCache(BaseCache):
         try:
             print_verbose(f"Get Redis Cache: key: {key}")
             cached_response = self.redis_client.get(key)
-            print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}")
+            print_verbose(
+                f"Got Redis Cache: key: {key}, cached_response {cached_response}"
+            )
             if cached_response != None:
                 # cached_response is in `b{} convert it to ModelResponse
-                cached_response = cached_response.decode("utf-8") # Convert bytes to string
-                try:
-                    cached_response = json.loads(cached_response) # Convert string to dictionary
-                except:
+                cached_response = cached_response.decode(
+                    "utf-8"
+                )  # Convert bytes to string
+                try:
+                    cached_response = json.loads(
+                        cached_response
+                    )  # Convert string to dictionary
+                except:
                     cached_response = ast.literal_eval(cached_response)
                 return cached_response
         except Exception as e:
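Redis returns bytes, so the read path above decodes to UTF-8 and tries json.loads first, falling back to ast.literal_eval for values stored as Python-literal strings. The same fallback in isolation (illustrative helper, not part of the module):

    import ast, json

    def parse_cached(raw: bytes):
        text = raw.decode("utf-8")  # bytes -> str
        try:
            return json.loads(text)  # value stored as JSON
        except json.JSONDecodeError:
            return ast.literal_eval(text)  # value stored as a repr()-style literal

    print(parse_cached(b'{"model": "gpt-4"}'))  # JSON payload
    print(parse_cached(b"{'model': 'gpt-4'}"))  # Python-literal payload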
@@ -105,34 +114,40 @@ class RedisCache(BaseCache):
     def flush_cache(self):
         self.redis_client.flushall()

-class DualCache(BaseCache):
+
+class DualCache(BaseCache):
     """
-    This updates both Redis and an in-memory cache simultaneously. 
-    When data is updated or inserted, it is written to both the in-memory cache + Redis. 
+    This updates both Redis and an in-memory cache simultaneously.
+    When data is updated or inserted, it is written to both the in-memory cache + Redis.
     This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
     """
-    def __init__(self, in_memory_cache: Optional[InMemoryCache] =None, redis_cache: Optional[RedisCache] =None) -> None:
+
+    def __init__(
+        self,
+        in_memory_cache: Optional[InMemoryCache] = None,
+        redis_cache: Optional[RedisCache] = None,
+    ) -> None:
         super().__init__()
         # If in_memory_cache is not provided, use the default InMemoryCache
         self.in_memory_cache = in_memory_cache or InMemoryCache()
         # If redis_cache is not provided, use the default RedisCache
         self.redis_cache = redis_cache

     def set_cache(self, key, value, **kwargs):
         # Update both Redis and in-memory cache
-        try: 
+        try:
             print_verbose(f"set cache: key: {key}; value: {value}")
             if self.in_memory_cache is not None:
                 self.in_memory_cache.set_cache(key, value, **kwargs)

             if self.redis_cache is not None:
                 self.redis_cache.set_cache(key, value, **kwargs)
-        except Exception as e: 
+        except Exception as e:
             print_verbose(e)

     def get_cache(self, key, **kwargs):
         # Try to fetch from in-memory cache first
-        try: 
+        try:
             print_verbose(f"get cache: cache key: {key}")
             result = None
             if self.in_memory_cache is not None:
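As the docstring says, DualCache is a write-through pair: every set goes to both layers, and reads prefer the in-memory copy. A minimal sketch with the Redis layer left off:

    from litellm.caching import DualCache, InMemoryCache

    dual_cache = DualCache(
        in_memory_cache=InMemoryCache(),
        redis_cache=None,  # pass a RedisCache(...) here to enable the shared layer
    )
    dual_cache.set_cache("router:deployment-1", '{"tpm": 1000}')
    print(dual_cache.get_cache("router:deployment-1"))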
@@ -141,7 +156,7 @@ class DualCache(BaseCache):
                 if in_memory_result is not None:
                     result = in_memory_result

-            if self.redis_cache is not None: 
+            if self.redis_cache is not None:
                 # If not found in in-memory cache, try fetching from Redis
                 redis_result = self.redis_cache.get_cache(key, **kwargs)
@@ -153,25 +168,28 @@ class DualCache(BaseCache):

             print_verbose(f"get cache: cache result: {result}")
             return result
-        except Exception as e: 
+        except Exception as e:
             traceback.print_exc()

     def flush_cache(self):
         if self.in_memory_cache is not None:
             self.in_memory_cache.flush_cache()
         if self.redis_cache is not None:
             self.redis_cache.flush_cache()

+
 #### LiteLLM.Completion / Embedding Cache ####
 class Cache:
     def __init__(
-            self,
-            type: Optional[Literal["local", "redis"]] = "local",
-            host: Optional[str] = None,
-            port: Optional[str] = None,
-            password: Optional[str] = None,
-            supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"],
-            **kwargs
+        self,
+        type: Optional[Literal["local", "redis"]] = "local",
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        password: Optional[str] = None,
+        supported_call_types: Optional[
+            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
+        ] = ["completion", "acompletion", "embedding", "aembedding"],
+        **kwargs,
     ):
         """
         Initializes the cache based on the given type.
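This __init__ is the user-facing entry point for litellm response caching. A typical hookup based on the parameters above; the Redis connection values are placeholders:

    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache(type="local")  # in-process cache for completion/embedding calls
    # or, backed by Redis:
    # litellm.cache = Cache(type="redis", host="localhost", port="6379", password="***")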
@@ -200,7 +218,7 @@ class Cache:
             litellm.success_callback.append("cache")
         if "cache" not in litellm._async_success_callback:
             litellm._async_success_callback.append("cache")
-        self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
+        self.supported_call_types = supported_call_types  # default to ["completion", "acompletion", "embedding", "aembedding"]

     def get_cache_key(self, *args, **kwargs):
         """
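supported_call_types controls which call types consult the cache. For example, to cache only embedding calls (a usage sketch of the same constructor):

    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache(type="local", supported_call_types=["embedding", "aembedding"])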
@@ -215,18 +233,37 @@ class Cache:
         """
         cache_key = ""
         print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")

         # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
         if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
             print_verbose(f"\nReturning preset cache key: {cache_key}")
             return kwargs.get("litellm_params", {}).get("preset_cache_key", None)

         # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
-        completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
-        embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
+        completion_kwargs = [
+            "model",
+            "messages",
+            "temperature",
+            "top_p",
+            "n",
+            "stop",
+            "max_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "response_format",
+            "seed",
+            "tools",
+            "tool_choice",
+        ]
+        embedding_only_kwargs = [
+            "input",
+            "encoding_format",
+        ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs

         # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = completion_kwargs + embedding_only_kwargs 
+        combined_kwargs = completion_kwargs + embedding_only_kwargs
         for param in combined_kwargs:
             # ignore litellm params here
             if param in kwargs:
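Because get_cache_key walks combined_kwargs in a fixed order and only appends parameters that are present, keyword order in the call does not change the resulting key. A quick illustration, reusing the Cache class above:

    from litellm.caching import Cache

    cache = Cache(type="local")
    msgs = [{"role": "user", "content": "hi"}]
    key_a = cache.get_cache_key(model="gpt-4", messages=msgs, temperature=0.2)
    key_b = cache.get_cache_key(temperature=0.2, messages=msgs, model="gpt-4")
    assert key_a == key_b  # same parameters, same key, regardless of ordering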
@@ -241,8 +278,8 @@ class Cache:
                         model_group = metadata.get("model_group", None)
                         caching_groups = metadata.get("caching_groups", None)
                         if caching_groups:
-                            for group in caching_groups: 
-                                if model_group in group: 
+                            for group in caching_groups:
+                                if model_group in group:
                                     caching_group = group
                                     break
                 if litellm_params is not None:
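caching_groups lets several deployments share one cache entry: when the request's model_group appears inside a group, the group (rather than the individual model) is folded into the key. A hedged sketch of the metadata shape this branch appears to expect:

    metadata = {
        "model_group": "gpt-4",
        "caching_groups": [("gpt-4", "azure/gpt-4-eu")],  # deployments allowed to share cached responses
    }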
@@ -251,23 +288,34 @@ class Cache:
                             model_group = metadata.get("model_group", None)
                             caching_groups = metadata.get("caching_groups", None)
                             if caching_groups:
-                                for group in caching_groups: 
-                                    if model_group in group: 
+                                for group in caching_groups:
+                                    if model_group in group:
                                         caching_group = group
                                         break
-                param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
+                param_value = (
+                    caching_group or model_group or kwargs[param]
+                )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
             else:
                 if kwargs[param] is None:
-                    continue # ignore None params
+                    continue  # ignore None params
                 param_value = kwargs[param]
-            cache_key+= f"{str(param)}: {str(param_value)}"
+            cache_key += f"{str(param)}: {str(param_value)}"
         print_verbose(f"\nCreated cache key: {cache_key}")
         return cache_key

     def generate_streaming_content(self, content):
         chunk_size = 5 # Adjust the chunk size as needed
         for i in range(0, len(content), chunk_size):
-            yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]}
+            yield {
+                "choices": [
+                    {
+                        "delta": {
+                            "role": "assistant",
+                            "content": content[i : i + chunk_size],
+                        }
+                    }
+                ]
+            }
             time.sleep(0.02)

     def get_cache(self, *args, **kwargs):
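generate_streaming_content replays a cached completion as OpenAI-style streaming chunks, five characters at a time. A small consumption sketch:

    from litellm.caching import Cache

    cache = Cache(type="local")
    for chunk in cache.generate_streaming_content("cached answer"):
        print(chunk["choices"][0]["delta"]["content"], end="")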
@@ -319,4 +367,4 @@ class Cache:
             pass

     async def _async_add_cache(self, result, *args, **kwargs):
-        self.add_cache(result, *args, **kwargs) 
+        self.add_cache(result, *args, **kwargs)