allow setting redis_semantic_cache_embedding_model

ishaan-jaff 2024-02-06 10:22:02 -08:00
parent 617716752e
commit 3c71eb1e71


@@ -232,6 +232,7 @@ class RedisSemanticCache(BaseCache):
         redis_url=None,
         similarity_threshold=None,
         use_async=False,
+        embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         from redisvl.index import SearchIndex
@@ -243,6 +244,7 @@ class RedisSemanticCache(BaseCache):
         if similarity_threshold is None:
             raise Exception("similarity_threshold must be provided, passed None")
         self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
         schema = {
             "index": {
                 "name": "litellm_semantic_cache_index",
@@ -322,7 +324,7 @@ class RedisSemanticCache(BaseCache):
         # create an embedding for prompt
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -359,7 +361,7 @@ class RedisSemanticCache(BaseCache):
         # convert to embedding
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -405,6 +407,7 @@ class RedisSemanticCache(BaseCache):
     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
 
         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -418,9 +421,21 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]
         # create an embedding for prompt
-        # convert to embedding
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
-        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
@@ -445,6 +460,7 @@ class RedisSemanticCache(BaseCache):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
         from redisvl.query import VectorQuery
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
 
         # query
@@ -454,9 +470,21 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]
-        # convert to embedding
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
-        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
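
Both async paths above apply the same rule: if the configured embedding model matches a model_name registered on the proxy router, the embedding call goes through llm_router.aembedding, so the router deployment's credentials and settings apply; otherwise it falls back to a direct litellm.aembedding call. A minimal standalone sketch of that rule, with a hypothetical model_list entry (the name "azure-embedding-model" is an assumed example, not part of this commit):

    # Sketch of the routing decision above; the model names are hypothetical.
    llm_model_list = [{"model_name": "azure-embedding-model"}]

    def resolve_embedding_route(embedding_model, llm_router, llm_model_list):
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and embedding_model in router_model_names:
            return "router"  # would call llm_router.aembedding(...)
        return "direct"  # would call litellm.aembedding(...)

    # A registered router model name takes the router path; anything else goes direct.
    assert resolve_embedding_route("azure-embedding-model", object(), llm_model_list) == "router"
    assert resolve_embedding_route("text-embedding-ada-002", object(), llm_model_list) == "direct"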
@@ -727,6 +755,7 @@ class Cache:
         s3_aws_session_token: Optional[str] = None,
         s3_config: Optional[Any] = None,
         redis_semantic_cache_use_async=False,
+        redis_semantic_cache_embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         """
@@ -757,6 +786,7 @@ class Cache:
                 password,
                 similarity_threshold=similarity_threshold,
                 use_async=redis_semantic_cache_use_async,
+                embedding_model=redis_semantic_cache_embedding_model,
                 **kwargs,
             )
         elif type == "local":
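
For reference, a usage sketch of the new parameter from the caller's side, assuming the "redis-semantic" cache type that this branch of Cache.__init__ handles (the connection values and embedding model name below are placeholders):

    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache(
        type="redis-semantic",  # the branch patched above
        host="localhost",       # placeholder connection settings
        port="6379",
        password="",
        similarity_threshold=0.8,  # required; RedisSemanticCache raises if omitted
        redis_semantic_cache_use_async=False,
        # the new knob; may also be a model_name registered on the proxy router
        redis_semantic_cache_embedding_model="text-embedding-ada-002",
    )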