Merge pull request #1809 from BerriAI/litellm_embedding_caching_updates

Support caching individual items in embedding list (Async embedding only)
Krish Dholakia 2024-02-03 21:04:23 -08:00 committed by GitHub
commit 28df60b609
13 changed files with 638 additions and 196 deletions
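For context, a minimal usage sketch of the behavior this PR enables (not part of the diff): per-item caching for async embedding calls. It assumes a reachable Redis instance configured via `REDIS_HOST`/`REDIS_PORT`/`REDIS_PASSWORD`, and the Azure embedding deployment name is illustrative.

```
# Usage sketch (illustrative, not part of the commit): per-item embedding caching.
import asyncio, os
import litellm
from litellm import aembedding
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)


async def main():
    # First call writes one cache entry per input string.
    await aembedding(
        model="azure/azure-embedding-model",  # illustrative deployment name
        input=["hello world", "good morning"],
        caching=True,
    )
    # Second call reuses the cached entry for "hello world"; only the new
    # string goes to the provider, and the merged response is flagged.
    response = await aembedding(
        model="azure/azure-embedding-model",
        input=["hello world", "a brand new sentence"],
        caching=True,
    )
    print(response._hidden_params.get("cache_hit"))  # True when any item was cached


asyncio.run(main())
```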

View file

@@ -285,6 +285,7 @@ openai_compatible_endpoints: List = [
     "api.endpoints.anyscale.com/v1",
     "api.deepinfra.com/v1/openai",
     "api.mistral.ai/v1",
+    "api.together.xyz/v1",
 ]
 # this is maintained for Exception Mapping
@@ -294,6 +295,7 @@ openai_compatible_providers: List = [
     "deepinfra",
     "perplexity",
     "xinference",
+    "together_ai",
 ]

View file

@@ -11,6 +11,7 @@
 import os
 import inspect
 import redis, litellm
+import redis.asyncio as async_redis
 from typing import List, Optional
@@ -67,7 +68,10 @@ def get_redis_url_from_environment():
     )


-def get_redis_client(**env_overrides):
+def _get_redis_client_logic(**env_overrides):
+    """
+    Common functionality across sync + async redis client implementations
+    """
     ### check if "os.environ/<key-name>" passed in
     for k, v in env_overrides.items():
         if isinstance(v, str) and v.startswith("os.environ/"):
@@ -85,9 +89,33 @@ def get_redis_client(**env_overrides):
         redis_kwargs.pop("port", None)
         redis_kwargs.pop("db", None)
         redis_kwargs.pop("password", None)
-        return redis.Redis.from_url(**redis_kwargs)
     elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
         raise ValueError("Either 'host' or 'url' must be specified for redis.")
     litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
+    return redis_kwargs


+def get_redis_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return redis.Redis.from_url(**redis_kwargs)
     return redis.Redis(**redis_kwargs)


+def get_redis_async_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.Redis.from_url(**redis_kwargs)
+    return async_redis.Redis(
+        socket_timeout=5,
+        **redis_kwargs,
+    )


+def get_redis_connection_pool(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.BlockingConnectionPool.from_url(
+            timeout=5, url=redis_kwargs["url"]
+        )
+    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
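A rough sketch of how these helpers are meant to be combined (the `RedisCache` class below does essentially this), assuming host/port-style configuration via the usual `REDIS_*` environment variables rather than a `REDIS_URL`; the key name is illustrative:

```
# Sketch only: one shared blocking connection pool, short-lived async clients from it.
import asyncio
from litellm._redis import get_redis_connection_pool, get_redis_async_client


async def main():
    pool = get_redis_connection_pool()  # reads REDIS_* env vars
    client = get_redis_async_client(connection_pool=pool)
    async with client as redis_client:
        await redis_client.set(name="litellm-demo-key", value="hello", ex=60)
        print(await redis_client.get("litellm-demo-key"))


asyncio.run(main())
```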

View file

@@ -8,7 +8,7 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 import litellm
-import time, logging
+import time, logging, asyncio
 import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any
 from openai._models import BaseModel as OpenAIObject
@@ -28,9 +28,18 @@ class BaseCache:
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError

+    async def async_set_cache(self, key, value, **kwargs):
+        raise NotImplementedError

     def get_cache(self, key, **kwargs):
         raise NotImplementedError

+    async def async_get_cache(self, key, **kwargs):
+        raise NotImplementedError

+    async def disconnect(self):
+        raise NotImplementedError


 class InMemoryCache(BaseCache):
     def __init__(self):
@@ -43,6 +52,16 @@ class InMemoryCache(BaseCache):
         if "ttl" in kwargs:
             self.ttl_dict[key] = time.time() + kwargs["ttl"]

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)

+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        for cache_key, cache_value in cache_list:
+            if ttl is not None:
+                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
+            else:
+                self.set_cache(key=cache_key, value=cache_value)

     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -57,21 +76,27 @@ class InMemoryCache(BaseCache):
             return cached_response
         return None

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)

     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()

+    async def disconnect(self):
+        pass

     def delete_cache(self, key):
         self.cache_dict.pop(key, None)
         self.ttl_dict.pop(key, None)


 class RedisCache(BaseCache):
-    def __init__(self, host=None, port=None, password=None, **kwargs):
-        import redis
-
-        # if users don't provider one, use the default litellm cache
-        from ._redis import get_redis_client
+    # if users don't provider one, use the default litellm cache
+    def __init__(self, host=None, port=None, password=None, **kwargs):
+        from ._redis import get_redis_client, get_redis_connection_pool

         redis_kwargs = {}
         if host is not None:
@@ -82,18 +107,84 @@ class RedisCache(BaseCache):
             redis_kwargs["password"] = password
         redis_kwargs.update(kwargs)
         self.redis_client = get_redis_client(**redis_kwargs)
+        self.redis_kwargs = redis_kwargs
+        self.async_redis_conn_pool = get_redis_connection_pool()

+    def init_async_client(self):
+        from ._redis import get_redis_async_client
+
+        return get_redis_async_client(
+            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+        )

     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
-        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
+        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
         try:
             self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

+    async def async_set_cache(self, key, value, **kwargs):
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
+            ttl = kwargs.get("ttl", None)
+            print_verbose(
+                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+            )
+            try:
+                await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        """
+        Use Redis Pipelines for bulk write operations
+        """
+        _redis_client = self.init_async_client()
+        try:
+            async with _redis_client as redis_client:
+                async with redis_client.pipeline(transaction=True) as pipe:
+                    # Iterate through each key-value pair in the cache_list and set them in the pipeline.
+                    for cache_key, cache_value in cache_list:
+                        print_verbose(
+                            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
+                        )
+                        # Set the value with a TTL if it's provided.
+                        if ttl is not None:
+                            pipe.setex(cache_key, ttl, json.dumps(cache_value))
+                        else:
+                            pipe.set(cache_key, json.dumps(cache_value))
+                    # Execute the pipeline and return the results.
+                    results = await pipe.execute()
+            print_verbose(f"pipeline results: {results}")
+            # Optionally, you could process 'results' to make sure that all set operations were successful.
+            return results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline write - {str(e)}")
+            # NON blocking - notify users Redis is throwing an exception
+            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+        # cached_response is in `b{} convert it to ModelResponse
+        cached_response = cached_response.decode("utf-8")  # Convert bytes to string
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response

     def get_cache(self, key, **kwargs):
         try:
             print_verbose(f"Get Redis Cache: key: {key}")
@@ -101,30 +192,40 @@ class RedisCache(BaseCache):
             print_verbose(
                 f"Got Redis Cache: key: {key}, cached_response {cached_response}"
             )
-            if cached_response != None:
-                # cached_response is in `b{} convert it to ModelResponse
-                cached_response = cached_response.decode(
-                    "utf-8"
-                )  # Convert bytes to string
-                try:
-                    cached_response = json.loads(
-                        cached_response
-                    )  # Convert string to dictionary
-                except:
-                    cached_response = ast.literal_eval(cached_response)
-                return cached_response
+            return self._get_cache_logic(cached_response=cached_response)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+    async def async_get_cache(self, key, **kwargs):
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
+            try:
+                print_verbose(f"Get Redis Cache: key: {key}")
+                cached_response = await redis_client.get(key)
+                print_verbose(
+                    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
+                )
+                response = self._get_cache_logic(cached_response=cached_response)
+                return response
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                traceback.print_exc()
+                logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

     def flush_cache(self):
         self.redis_client.flushall()

+    async def disconnect(self):
+        pass

     def delete_cache(self, key):
         self.redis_client.delete(key)


 class S3Cache(BaseCache):
     def __init__(
         self,
@@ -202,6 +303,9 @@ class S3Cache(BaseCache):
             # NON blocking - notify users S3 is throwing an exception
             print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)

     def get_cache(self, key, **kwargs):
         import boto3, botocore
@@ -244,6 +348,9 @@ class S3Cache(BaseCache):
             traceback.print_exc()
             print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)

     def flush_cache(self):
         pass
@@ -361,9 +468,9 @@ class Cache:
         """
         if type == "redis":
             self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
-        if type == "local":
+        elif type == "local":
             self.cache = InMemoryCache()
-        if type == "s3":
+        elif type == "s3":
             self.cache = S3Cache(
                 s3_bucket_name=s3_bucket_name,
                 s3_region_name=s3_region_name,
@@ -489,6 +596,45 @@ class Cache:
             }
             time.sleep(0.02)

+    def _get_cache_logic(
+        self,
+        cached_result: Optional[Any],
+        max_age: Optional[float],
+    ):
+        """
+        Common get cache logic across sync + async implementations
+        """
+        # Check if a timestamp was stored with the cached response
+        if (
+            cached_result is not None
+            and isinstance(cached_result, dict)
+            and "timestamp" in cached_result
+        ):
+            timestamp = cached_result["timestamp"]
+            current_time = time.time()
+            # Calculate age of the cached response
+            response_age = current_time - timestamp
+            # Check if the cached response is older than the max-age
+            if max_age is not None and response_age > max_age:
+                return None  # Cached response is too old
+            # If the response is fresh, or there's no max-age requirement, return the cached response
+            # cached_response is in `b{} convert it to ModelResponse
+            cached_response = cached_result.get("response")
+            try:
+                if isinstance(cached_response, dict):
+                    pass
+                else:
+                    cached_response = json.loads(
+                        cached_response  # type: ignore
+                    )  # Convert string to dictionary
+            except:
+                cached_response = ast.literal_eval(cached_response)  # type: ignore
+            return cached_response
+        return cached_result

     def get_cache(self, *args, **kwargs):
         """
         Retrieves the cached result for the given arguments.
@@ -511,54 +657,40 @@ class Cache:
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
                 cached_result = self.cache.get_cache(cache_key)
-                # Check if a timestamp was stored with the cached response
-                if (
-                    cached_result is not None
-                    and isinstance(cached_result, dict)
-                    and "timestamp" in cached_result
-                    and max_age is not None
-                ):
-                    timestamp = cached_result["timestamp"]
-                    current_time = time.time()
-                    # Calculate age of the cached response
-                    response_age = current_time - timestamp
-                    # Check if the cached response is older than the max-age
-                    if response_age > max_age:
-                        print_verbose(
-                            f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
-                        )
-                        return None  # Cached response is too old
-                    # If the response is fresh, or there's no max-age requirement, return the cached response
-                    # cached_response is in `b{} convert it to ModelResponse
-                    cached_response = cached_result.get("response")
-                    try:
-                        if isinstance(cached_response, dict):
-                            pass
-                        else:
-                            cached_response = json.loads(
-                                cached_response
-                            )  # Convert string to dictionary
-                    except:
-                        cached_response = ast.literal_eval(cached_response)
-                    return cached_response
-                return cached_result
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
         except Exception as e:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None

-    def add_cache(self, result, *args, **kwargs):
-        """
-        Adds a result to the cache.
-        Args:
-            *args: args to litellm.completion() or embedding()
-            **kwargs: kwargs to litellm.completion() or embedding()
-        Returns:
-            None
+    async def async_get_cache(self, *args, **kwargs):
+        """
+        Async get cache implementation.
+        Used for embedding calls in async wrapper
+        """
+        try:  # never block execution
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
+                cached_result = await self.cache.async_get_cache(cache_key)
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
+        except Exception as e:
+            print_verbose(f"An exception occurred: {traceback.format_exc()}")
+            return None

+    def _add_cache_logic(self, result, *args, **kwargs):
+        """
+        Common implementation across sync + async add_cache functions
         """
         try:
             if "cache_key" in kwargs:
@@ -577,14 +709,82 @@ class Cache:
                         if k == "ttl":
                             kwargs["ttl"] = v
                 cached_data = {"timestamp": time.time(), "response": result}
-                self.cache.set_cache(cache_key, cached_data, **kwargs)
+                return cache_key, cached_data, kwargs
+            else:
+                raise Exception("cache key is None")
+        except Exception as e:
+            raise e

+    def add_cache(self, result, *args, **kwargs):
+        """
+        Adds a result to the cache.
+        Args:
+            *args: args to litellm.completion() or embedding()
+            **kwargs: kwargs to litellm.completion() or embedding()
+        Returns:
+            None
+        """
+        try:
+            cache_key, cached_data, kwargs = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            self.cache.set_cache(cache_key, cached_data, **kwargs)
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
             pass

-    async def _async_add_cache(self, result, *args, **kwargs):
-        self.add_cache(result, *args, **kwargs)
+    async def async_add_cache(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache
+        """
+        try:
+            cache_key, cached_data, kwargs = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()

+    async def async_add_cache_pipeline(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache for Embedding calls
+        Does a bulk write, to prevent using too many clients
+        """
+        try:
+            cache_list = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = litellm.cache.get_cache_key(
+                    *args, **{**kwargs, "input": i}
+                )
+                kwargs["cache_key"] = preset_cache_key
+                embedding_response = result.data[idx]
+                cache_key, cached_data, kwargs = self._add_cache_logic(
+                    result=embedding_response,
+                    *args,
+                    **kwargs,
+                )
+                cache_list.append((cache_key, cached_data))
+            if hasattr(self.cache, "async_set_cache_pipeline"):
+                await self.cache.async_set_cache_pipeline(cache_list=cache_list)
+            else:
+                tasks = []
+                for val in cache_list:
+                    tasks.append(
+                        self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+                    )
+                await asyncio.gather(*tasks)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()

+    async def disconnect(self):
+        if hasattr(self.cache, "disconnect"):
+            await self.cache.disconnect()


 def enable_cache(
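A short sketch of how these new `Cache`-level async methods fit together — essentially what the async wrapper in `utils.py` further down in this diff does. The model name and inputs are illustrative, and the in-memory backend is used only to keep the example self-contained:

```
# Sketch only: per-item lookups against the new async Cache API.
import asyncio
import litellm
from litellm.caching import Cache

litellm.cache = Cache(type="local")  # in-memory backend also implements the async methods


async def main():
    inputs = ["hello world", "good morning"]
    # One cache key per input string - the same scheme async_add_cache_pipeline uses.
    keys = [
        litellm.cache.get_cache_key(model="text-embedding-ada-002", input=i)
        for i in inputs
    ]
    hits = await asyncio.gather(
        *[litellm.cache.async_get_cache(cache_key=k) for k in keys]
    )
    print(hits)  # [None, None] on a cold cache


asyncio.run(main())
```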

View file

@@ -440,8 +440,8 @@ class OpenAIChatCompletion(BaseLLM):
                 input=data["messages"],
                 api_key=api_key,
                 additional_args={
-                    "headers": headers,
-                    "api_base": api_base,
+                    "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
+                    "api_base": openai_client._base_url._uri_reference,
                     "acompletion": False,
                     "complete_input_dict": data,
                 },

View file

@@ -1,3 +1,7 @@
+"""
+Deprecated. We now do together ai calls via the openai client.
+Reference: https://docs.together.ai/docs/openai-api-compatibility
+"""
 import os, types
 import json
 from enum import Enum

View file

@@ -10,6 +10,7 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any, Literal, Union
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
@@ -234,6 +235,9 @@ async def acompletion(
         "model_list": model_list,
         "acompletion": True,  # assuming this is a required parameter
     }
+    _, custom_llm_provider, _, _ = get_llm_provider(
+        model=model, api_base=completion_kwargs.get("base_url", None)
+    )
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)
@@ -245,7 +249,6 @@ async def acompletion(
         _, custom_llm_provider, _, _ = get_llm_provider(
             model=model, api_base=kwargs.get("api_base", None)
         )
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
@@ -788,6 +791,7 @@ def completion(
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
             or custom_llm_provider == "openai"
+            or custom_llm_provider == "together_ai"
             or "ft:gpt-3.5-turbo" in model  # finetune gpt-3.5-turbo
         ):  # allow user to make an openai call with a custom base
             # note: if a user sets a custom base - we should ensure this works
@@ -1327,6 +1331,9 @@ def completion(
             or ("togethercomputer" in model)
             or (model in litellm.together_ai_models)
         ):
+            """
+            Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility
+            """
             custom_llm_provider = "together_ai"
             together_ai_key = (
                 api_key

View file

@@ -380,7 +380,7 @@ def run_server(
             import gunicorn.app.base
         except:
             raise ImportError(
-                "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`"
+                "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
             )
     if config is not None:
@@ -444,6 +444,7 @@ def run_server(
         )
     if port == 8000 and is_port_in_use(port):
         port = random.randint(1024, 49152)
+    from litellm.proxy.proxy_server import app

     if run_gunicorn == False:
@@ -521,5 +522,6 @@ def run_server(
         ).run()  # Run gunicorn


 if __name__ == "__main__":
     run_server()

View file

@@ -7,6 +7,20 @@ import secrets, subprocess
 import hashlib, uuid
 import warnings
 import importlib
+import warnings

+def showwarning(message, category, filename, lineno, file=None, line=None):
+    traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
+    if file is not None:
+        file.write(traceback_info)

+warnings.showwarning = showwarning
+warnings.filterwarnings("default", category=UserWarning)

+# Your client code here

 messages: list = []
 sys.path.insert(
@@ -4053,9 +4067,12 @@ def _has_user_setup_sso():
 async def shutdown_event():
     global prisma_client, master_key, user_custom_auth, user_custom_key_generate
     if prisma_client:
         verbose_proxy_logger.debug("Disconnecting from Prisma")
         await prisma_client.disconnect()

+    if litellm.cache is not None:
+        await litellm.cache.disconnect()

     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()

View file

@@ -21,10 +21,18 @@ def setup_and_teardown():
     import litellm

     importlib.reload(litellm)
+    import asyncio

+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    asyncio.set_event_loop(loop)
     print(litellm)
     # from litellm import Router, completion, aembedding, acompletion, embedding
     yield

+    # Teardown code (executes after the yield point)
+    loop.close()  # Close the loop created earlier
+    asyncio.set_event_loop(None)  # Remove the reference to the loop


 def pytest_collection_modifyitems(config, items):
     # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests

View file

@@ -11,10 +11,10 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm import embedding, completion
+from litellm import embedding, completion, aembedding
 from litellm.caching import Cache
 import random
-import hashlib
+import hashlib, asyncio

 # litellm.set_verbose=True
@@ -106,10 +106,7 @@ def test_caching_with_cache_controls():
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
-        assert (
-            response2["choices"][0]["message"]["content"]
-            == response1["choices"][0]["message"]["content"]
-        )
+        assert response2["id"] == response1["id"]
     except Exception as e:
         print(f"error occurred: {traceback.format_exc()}")
         pytest.fail(f"Error occurred: {e}")
@@ -259,6 +256,84 @@ def test_embedding_caching_azure():
 # test_embedding_caching_azure()


+@pytest.mark.asyncio
+async def test_embedding_caching_azure_individual_items():
+    """
+    Tests caching for individual items in an embedding list
+    - Cache an item
+    - call aembedding(..) with the item + 1 unique item
+    - compare to a 2nd aembedding (...) with 2 unique items
+    ```
+    embedding_1 = ["hey how's it going", "I'm doing well"]
+    embedding_val_1 = embedding(...)
+    embedding_2 = ["hey how's it going", "I'm fine"]
+    embedding_val_2 = embedding(...)
+    assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"]
+    ```
+    """
+    litellm.cache = Cache()
+    common_msg = f"hey how's it going {uuid.uuid4()}"
+    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [common_msg]
+    embedding_2 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+    ]
+    embedding_val_1 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_1, caching=True
+    )
+    embedding_val_2 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_2, caching=True
+    )
+    print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
+    assert embedding_val_2._hidden_params["cache_hit"] == True


+@pytest.mark.asyncio
+async def test_redis_cache_basic():
+    """
+    Init redis client
+    - write to client
+    - read from client
+    """
+    litellm.set_verbose = False
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
+    ]
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    cache_key = litellm.cache.get_cache_key(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"cache_key: {cache_key}")
+    litellm.cache.add_cache(result=response1, cache_key=cache_key)
+    print(f"cache key pre async get: {cache_key}")
+    stored_val = await litellm.cache.async_get_cache(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"stored_val: {stored_val}")
+    assert stored_val["id"] == response1.id


 def test_redis_cache_completion():
     litellm.set_verbose = False
@@ -406,7 +481,7 @@ def test_redis_cache_acompletion_stream():
     import asyncio

     try:
-        litellm.set_verbose = True
+        litellm.set_verbose = False
         random_word = generate_random_word()
         messages = [
             {
@@ -434,7 +509,6 @@ def test_redis_cache_acompletion_stream():
                 stream=True,
             )
             async for chunk in response1:
-                print(chunk)
                 response_1_content += chunk.choices[0].delta.content or ""
             print(response_1_content)
@@ -452,7 +526,6 @@ def test_redis_cache_acompletion_stream():
                 stream=True,
            )
             async for chunk in response2:
-                print(chunk)
                 response_2_content += chunk.choices[0].delta.content or ""
             print(response_2_content)
@@ -914,101 +987,3 @@ def test_cache_context_managers():
 # test_cache_context_managers()

-# test_custom_redis_cache_params()
-# def test_redis_cache_with_ttl():
-# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
-# sample_model_response_object_str = """{
-# "choices": [
-# {
-# "finish_reason": "stop",
-# "index": 0,
-# "message": {
-# "role": "assistant",
-# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-# }
-# }
-# ],
-# "created": 1691429984.3852863,
-# "model": "claude-instant-1",
-# "usage": {
-# "prompt_tokens": 18,
-# "completion_tokens": 23,
-# "total_tokens": 41
-# }
-# }"""
-# sample_model_response_object = {
-# "choices": [
-# {
-# "finish_reason": "stop",
-# "index": 0,
-# "message": {
-# "role": "assistant",
-# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-# }
-# }
-# ],
-# "created": 1691429984.3852863,
-# "model": "claude-instant-1",
-# "usage": {
-# "prompt_tokens": 18,
-# "completion_tokens": 23,
-# "total_tokens": 41
-# }
-# }
-# cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
-# cached_value = cache.get_cache(cache_key="test_key")
-# print(f"cached-value: {cached_value}")
-# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
-# time.sleep(2)
-# assert cache.get_cache(cache_key="test_key") is None
-# # test_redis_cache_with_ttl()
-# def test_in_memory_cache_with_ttl():
-# cache = Cache(type="local")
-# sample_model_response_object_str = """{
-# "choices": [
-# {
-# "finish_reason": "stop",
-# "index": 0,
-# "message": {
-# "role": "assistant",
-# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-# }
-# }
-# ],
-# "created": 1691429984.3852863,
-# "model": "claude-instant-1",
-# "usage": {
-# "prompt_tokens": 18,
-# "completion_tokens": 23,
-# "total_tokens": 41
-# }
-# }"""
-# sample_model_response_object = {
-# "choices": [
-# {
-# "finish_reason": "stop",
-# "index": 0,
-# "message": {
-# "role": "assistant",
-# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-# }
-# }
-# ],
-# "created": 1691429984.3852863,
-# "model": "claude-instant-1",
-# "usage": {
-# "prompt_tokens": 18,
-# "completion_tokens": 23,
-# "total_tokens": 41
-# }
-# }
-# cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
-# cached_value = cache.get_cache(cache_key="test_key")
-# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
-# time.sleep(2)
-# assert cache.get_cache(cache_key="test_key") is None
-# # test_in_memory_cache_with_ttl()

View file

@@ -1994,6 +1994,7 @@ def test_completion_palm_stream():

 def test_completion_together_ai_stream():
+    litellm.set_verbose = True
     user_message = "Write 1pg about YC & litellm"
     messages = [{"content": user_message, "role": "user"}]
     try:

View file

@@ -556,7 +556,6 @@ async def test_async_chat_bedrock_stream():
 # asyncio.run(test_async_chat_bedrock_stream())

 ## Test Sagemaker + Async
 @pytest.mark.asyncio
 async def test_async_chat_sagemaker_stream():
@@ -725,7 +724,7 @@ async def test_async_embedding_bedrock():
         response = await litellm.aembedding(
             model="bedrock/cohere.embed-multilingual-v3",
             input=["good morning from litellm"],
-            aws_region_name="os.environ/AWS_REGION_NAME_2",
+            aws_region_name="us-east-1",
         )
         await asyncio.sleep(1)
         print(f"customHandler_success.errors: {customHandler_success.errors}")
@@ -758,6 +757,7 @@ async def test_async_embedding_bedrock():
 ## Test Azure - completion, embedding
 @pytest.mark.asyncio
 async def test_async_completion_azure_caching():
+    litellm.set_verbose = True
     customHandler_caching = CompletionCustomHandler()
     litellm.cache = Cache(
         type="redis",
@@ -812,6 +812,7 @@ async def test_async_embedding_azure_caching():
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
     print(customHandler_caching.states)
+    print(customHandler_caching.errors)
     assert len(customHandler_caching.errors) == 0
     assert len(customHandler_caching.states) == 4  # pre, post, success, success

View file

@@ -55,6 +55,7 @@ from .integrations.litedebugger import LiteDebugger
 from .proxy._types import KeyManagementSystem
 from openai import OpenAIError as OriginalError
 from openai._models import BaseModel as OpenAIObject
+from .caching import S3Cache
 from .exceptions import (
     AuthenticationError,
     BadRequestError,
@@ -862,6 +863,7 @@ class Logging:
                 curl_command += additional_args.get("request_str", None)
             elif api_base == "":
                 curl_command = self.model_call_details
+            print_verbose(f"\033[92m{curl_command}\033[0m\n")
             verbose_logger.info(f"\033[92m{curl_command}\033[0m\n")
             if self.logger_fn and callable(self.logger_fn):
                 try:
@@ -2196,12 +2198,21 @@ def client(original_function):
            )
            # if caching is false or cache["no-cache"]==True, don't run this
            if (
-               (kwargs.get("caching", None) is None and litellm.cache is not None)
-               or kwargs.get("caching", False) == True
-               or (
-                   kwargs.get("cache", None) is not None
-                   and kwargs.get("cache", {}).get("no-cache", False) != True
+               (
+                   (
+                       kwargs.get("caching", None) is None
+                       and kwargs.get("cache", None) is None
+                       and litellm.cache is not None
+                   )
+                   or kwargs.get("caching", False) == True
+                   or (
+                       kwargs.get("cache", None) is not None
+                       and kwargs.get("cache", {}).get("no-cache", False) != True
+                   )
                )
+               and kwargs.get("aembedding", False) != True
+               and kwargs.get("acompletion", False) != True
+               and kwargs.get("aimg_generation", False) != True
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose(f"INSIDE CHECKING CACHE")
@@ -2435,6 +2446,7 @@ def client(original_function):
        result = None
        logging_obj = kwargs.get("litellm_logging_obj", None)
        # only set litellm_call_id if its not in kwargs
+       call_type = original_function.__name__
        if "litellm_call_id" not in kwargs:
            kwargs["litellm_call_id"] = str(uuid.uuid4())
        try:
@@ -2465,8 +2477,14 @@ def client(original_function):
                f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
            )
            # if caching is false, don't run this
+           final_embedding_cached_response = None
            if (
-               (kwargs.get("caching", None) is None and litellm.cache is not None)
+               (
+                   kwargs.get("caching", None) is None
+                   and kwargs.get("cache", None) is None
+                   and litellm.cache is not None
+               )
                or kwargs.get("caching", False) == True
                or (
                    kwargs.get("cache", None) is not None
@@ -2481,8 +2499,36 @@ def client(original_function):
                    in litellm.cache.supported_call_types
                ):
                    print_verbose(f"Checking Cache")
-                   cached_result = litellm.cache.get_cache(*args, **kwargs)
-                   if cached_result != None:
+                   if call_type == CallTypes.aembedding.value and isinstance(
+                       kwargs["input"], list
+                   ):
+                       tasks = []
+                       for idx, i in enumerate(kwargs["input"]):
+                           preset_cache_key = litellm.cache.get_cache_key(
+                               *args, **{**kwargs, "input": i}
+                           )
+                           tasks.append(
+                               litellm.cache.async_get_cache(
+                                   cache_key=preset_cache_key
+                               )
+                           )
+                       cached_result = await asyncio.gather(*tasks)
+                       ## check if cached result is None ##
+                       if cached_result is not None and isinstance(
+                           cached_result, list
+                       ):
+                           if len(cached_result) == 1 and cached_result[0] is None:
+                               cached_result = None
+                   else:
+                       preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
+                       kwargs[
+                           "preset_cache_key"
+                       ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+                       cached_result = litellm.cache.get_cache(*args, **kwargs)
+                   if cached_result is not None and not isinstance(
+                       cached_result, list
+                   ):
                        print_verbose(f"Cache Hit!")
                        call_type = original_function.__name__
                        if call_type == CallTypes.acompletion.value and isinstance(
@@ -2555,6 +2601,103 @@ def client(original_function):
                            args=(cached_result, start_time, end_time, cache_hit),
                        ).start()
                        return cached_result
+                   elif (
+                       call_type == CallTypes.aembedding.value
+                       and cached_result is not None
+                       and isinstance(cached_result, list)
+                       and litellm.cache is not None
+                       and not isinstance(
+                           litellm.cache.cache, S3Cache
+                       )  # s3 doesn't support bulk writing. Exclude.
+                   ):
+                       remaining_list = []
+                       non_null_list = []
+                       for idx, cr in enumerate(cached_result):
+                           if cr is None:
+                               remaining_list.append(kwargs["input"][idx])
+                           else:
+                               non_null_list.append((idx, cr))
+                       original_kwargs_input = kwargs["input"]
+                       kwargs["input"] = remaining_list
+                       if len(non_null_list) > 0:
+                           print_verbose(
+                               f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
+                           )
+                           final_embedding_cached_response = EmbeddingResponse(
+                               model=kwargs.get("model"),
+                               data=[None] * len(original_kwargs_input),
+                           )
+                           final_embedding_cached_response._hidden_params[
+                               "cache_hit"
+                           ] = True
+                           for val in non_null_list:
+                               idx, cr = val  # (idx, cr) tuple
+                               if cr is not None:
+                                   final_embedding_cached_response.data[idx] = cr
+                       if len(remaining_list) == 0:
+                           # LOG SUCCESS
+                           cache_hit = True
+                           end_time = datetime.datetime.now()
+                           (
+                               model,
+                               custom_llm_provider,
+                               dynamic_api_key,
+                               api_base,
+                           ) = litellm.get_llm_provider(
+                               model=model,
+                               custom_llm_provider=kwargs.get(
+                                   "custom_llm_provider", None
+                               ),
+                               api_base=kwargs.get("api_base", None),
+                               api_key=kwargs.get("api_key", None),
+                           )
+                           print_verbose(
+                               f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
+                           )
+                           logging_obj.update_environment_variables(
+                               model=model,
+                               user=kwargs.get("user", None),
+                               optional_params={},
+                               litellm_params={
+                                   "logger_fn": kwargs.get("logger_fn", None),
+                                   "acompletion": True,
+                                   "metadata": kwargs.get("metadata", {}),
+                                   "model_info": kwargs.get("model_info", {}),
+                                   "proxy_server_request": kwargs.get(
+                                       "proxy_server_request", None
+                                   ),
+                                   "preset_cache_key": kwargs.get(
+                                       "preset_cache_key", None
+                                   ),
+                                   "stream_response": kwargs.get(
+                                       "stream_response", {}
+                                   ),
+                               },
+                               input=kwargs.get("messages", ""),
+                               api_key=kwargs.get("api_key", None),
+                               original_response=str(final_embedding_cached_response),
+                               additional_args=None,
+                               stream=kwargs.get("stream", False),
+                           )
+                           asyncio.create_task(
+                               logging_obj.async_success_handler(
+                                   final_embedding_cached_response,
+                                   start_time,
+                                   end_time,
+                                   cache_hit,
+                               )
+                           )
+                           threading.Thread(
+                               target=logging_obj.success_handler,
+                               args=(
+                                   final_embedding_cached_response,
+                                   start_time,
+                                   end_time,
+                                   cache_hit,
+                               ),
+                           ).start()
+                           return final_embedding_cached_response
            # MODEL CALL
            result = await original_function(*args, **kwargs)
            end_time = datetime.datetime.now()
@@ -2587,12 +2730,28 @@ def client(original_function):
                if isinstance(result, litellm.ModelResponse) or isinstance(
                    result, litellm.EmbeddingResponse
                ):
-                   asyncio.create_task(
-                       litellm.cache._async_add_cache(result.json(), *args, **kwargs)
-                   )
+                   if (
+                       isinstance(result, EmbeddingResponse)
+                       and isinstance(kwargs["input"], list)
+                       and litellm.cache is not None
+                       and not isinstance(
+                           litellm.cache.cache, S3Cache
+                       )  # s3 doesn't support bulk writing. Exclude.
+                   ):
+                       asyncio.create_task(
+                           litellm.cache.async_add_cache_pipeline(
+                               result, *args, **kwargs
+                           )
+                       )
+                   else:
+                       asyncio.create_task(
+                           litellm.cache.async_add_cache(
+                               result.json(), *args, **kwargs
+                           )
+                       )
                else:
                    asyncio.create_task(
-                       litellm.cache._async_add_cache(result, *args, **kwargs)
+                       litellm.cache.async_add_cache(result, *args, **kwargs)
                    )
            # LOG SUCCESS - handle streaming success logging in the _next_ object
            print_verbose(
@@ -2616,6 +2775,27 @@ def client(original_function):
                result._response_ms = (
                    end_time - start_time
                ).total_seconds() * 1000  # return response latency in ms like openai
+
+           if (
+               isinstance(result, EmbeddingResponse)
+               and final_embedding_cached_response is not None
+           ):
+               idx = 0
+               final_data_list = []
+               for item in final_embedding_cached_response.data:
+                   if item is None:
+                       final_data_list.append(result.data[idx])
+                       idx += 1
+                   else:
+                       final_data_list.append(item)
+
+               final_embedding_cached_response.data = final_data_list
+               final_embedding_cached_response._hidden_params["cache_hit"] = True
+               final_embedding_cached_response._response_ms = (
+                   end_time - start_time
+               ).total_seconds() * 1000
+               return final_embedding_cached_response
+
            return result
        except Exception as e:
            traceback_exception = traceback.format_exc()
@@ -3275,7 +3455,11 @@ def completion_cost(
        else:
            raise Exception(f"Model={model} not found in completion cost model map")
        # Calculate cost based on prompt_tokens, completion_tokens
-       if "togethercomputer" in model or "together_ai" in model:
+       if (
+           "togethercomputer" in model
+           or "together_ai" in model
+           or custom_llm_provider == "together_ai"
+       ):
            # together ai prices based on size of llm
            # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json
            model = get_model_params_and_category(model)
@@ -3864,7 +4048,7 @@ def get_optional_params(
        _check_valid_arg(supported_params=supported_params)
        if stream:
-           optional_params["stream_tokens"] = stream
+           optional_params["stream"] = stream
        if temperature is not None:
            optional_params["temperature"] = temperature
        if top_p is not None:
@@ -4498,6 +4682,14 @@ def get_llm_provider(
            # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
            api_base = "https://api.voyageai.com/v1"
            dynamic_api_key = get_secret("VOYAGE_API_KEY")
+       elif custom_llm_provider == "together_ai":
+           api_base = "https://api.together.xyz/v1"
+           dynamic_api_key = (
+               get_secret("TOGETHER_API_KEY")
+               or get_secret("TOGETHER_AI_API_KEY")
+               or get_secret("TOGETHERAI_API_KEY")
+               or get_secret("TOGETHER_AI_TOKEN")
+           )
        return model, custom_llm_provider, dynamic_api_key, api_base
    elif model.split("/", 1)[0] in litellm.provider_list:
        custom_llm_provider = model.split("/", 1)[0]
@@ -6383,7 +6575,12 @@ def exception_type(
                        message=f"BedrockException - {original_exception.message}",
                        llm_provider="bedrock",
                        model=model,
-                       response=original_exception.response,
+                       response=httpx.Response(
+                           status_code=500,
+                           request=httpx.Request(
+                               method="POST", url="https://api.openai.com/v1/"
+                           ),
+                       ),
                    )
                elif original_exception.status_code == 401:
                    exception_mapping_worked = True