fix(utils.py): exclude s3 caching from individual item caching for embedding list

S3 doesn't support bulk uploads, so caching each embedding item individually would slow down calls.

https://github.com/BerriAI/litellm/pull/1417
This commit is contained in:
Krrish Dholakia 2024-01-13 16:19:30 +05:30
parent 0bcca3fed3
commit f08bb7e41f
2 changed files with 14 additions and 4 deletions

View file

@ -444,9 +444,9 @@ class Cache:
""" """
if type == "redis": if type == "redis":
self.cache: BaseCache = RedisCache(host, port, password, **kwargs) self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
if type == "local": elif type == "local":
self.cache = InMemoryCache() self.cache = InMemoryCache()
if type == "s3": elif type == "s3":
self.cache = S3Cache( self.cache = S3Cache(
s3_bucket_name=s3_bucket_name, s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name, s3_region_name=s3_region_name,

View file

@ -53,6 +53,7 @@ from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache
from .exceptions import ( from .exceptions import (
AuthenticationError, AuthenticationError,
BadRequestError, BadRequestError,
@ -2338,6 +2339,10 @@ def client(original_function):
call_type == CallTypes.aembedding.value call_type == CallTypes.aembedding.value
and cached_result is not None and cached_result is not None
and isinstance(cached_result, list) and isinstance(cached_result, list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
): ):
remaining_list = [] remaining_list = []
non_null_list = [] non_null_list = []
@ -2458,8 +2463,13 @@ def client(original_function):
if isinstance(result, litellm.ModelResponse) or isinstance( if isinstance(result, litellm.ModelResponse) or isinstance(
result, litellm.EmbeddingResponse result, litellm.EmbeddingResponse
): ):
if isinstance(result, EmbeddingResponse) and isinstance( if (
kwargs["input"], list isinstance(result, EmbeddingResponse)
and isinstance(kwargs["input"], list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
): ):
asyncio.create_task( asyncio.create_task(
litellm.cache.async_add_cache_pipeline( litellm.cache.async_add_cache_pipeline(