From 3c71eb1e71d68e7c9a5f54bf9017fcedb206ac13 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 6 Feb 2024 10:22:02 -0800
Subject: [PATCH] allow setting redis_semantic cache_embedding model

---
 litellm/caching.py | 54 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 42 insertions(+), 12 deletions(-)

diff --git a/litellm/caching.py b/litellm/caching.py
index 133d1db6dd..6bf53ea451 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -232,6 +232,7 @@ class RedisSemanticCache(BaseCache):
         redis_url=None,
         similarity_threshold=None,
         use_async=False,
+        embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         from redisvl.index import SearchIndex
@@ -243,6 +244,7 @@ class RedisSemanticCache(BaseCache):
         if similarity_threshold is None:
             raise Exception("similarity_threshold must be provided, passed None")
         self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
         schema = {
             "index": {
                 "name": "litellm_semantic_cache_index",
@@ -322,7 +324,7 @@ class RedisSemanticCache(BaseCache):

         # create an embedding for prompt
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -359,7 +361,7 @@ class RedisSemanticCache(BaseCache):

         # convert to embedding
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -405,6 +407,7 @@ class RedisSemanticCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list

         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -418,12 +421,24 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]
         # create an embedding for prompt
-
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
         )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )

         # get the embedding
         embedding = embedding_response["data"][0]["embedding"]
@@ -445,6 +460,7 @@ class RedisSemanticCache(BaseCache):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
         from redisvl.query import VectorQuery
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list

         # query

@@ -454,12 +470,24 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]

-        # convert to embedding
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
         )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )

         # get the embedding
         embedding = embedding_response["data"][0]["embedding"]
@@ -727,6 +755,7 @@ class Cache:
         s3_aws_session_token: Optional[str] = None,
         s3_config: Optional[Any] = None,
         redis_semantic_cache_use_async=False,
+        redis_semantic_cache_embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         """
@@ -757,6 +786,7 @@ class Cache:
                 password,
                 similarity_threshold=similarity_threshold,
                 use_async=redis_semantic_cache_use_async,
+                embedding_model=redis_semantic_cache_embedding_model,
                 **kwargs,
             )
         elif type == "local":
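A minimal usage sketch of the new parameter, not part of the patch itself: it assumes the "redis-semantic" cache type string and the existing host/port/password/similarity_threshold arguments of litellm.Cache; the connection values and the embedding model name below are illustrative placeholders.

import litellm
from litellm.caching import Cache

# Sketch: enable the Redis semantic cache with a caller-chosen embedding model.
# Connection values are placeholders, not defaults from the patch.
litellm.cache = Cache(
    type="redis-semantic",
    host="localhost",
    port="6379",
    password="",
    similarity_threshold=0.8,
    redis_semantic_cache_use_async=True,
    # new in this patch; overrides the default "text-embedding-ada-002"
    redis_semantic_cache_embedding_model="text-embedding-3-small",
)

When running behind the proxy, the async paths above will route this model through llm_router.aembedding() if its model_name is present in the router's deployment list, and fall back to litellm.aembedding() otherwise.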