forked from phoenix/litellm-mirror
allow setting redis_semantic cache_embedding model
This commit is contained in:
parent 751fb1af89
commit 05f379234d
1 changed file with 42 additions and 12 deletions
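Summary of the change, as read from the hunks below: RedisSemanticCache gains an embedding_model parameter (default "text-embedding-ada-002"), the synchronous set/get paths embed with self.embedding_model instead of a hardcoded model name, and the async set/get paths additionally route the embedding call through the proxy's llm_router when the configured model is registered there. Cache.__init__ exposes the new knob as redis_semantic_cache_embedding_model.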
@@ -232,6 +232,7 @@ class RedisSemanticCache(BaseCache):
         redis_url=None,
         similarity_threshold=None,
         use_async=False,
+        embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         from redisvl.index import SearchIndex
@@ -243,6 +244,7 @@ class RedisSemanticCache(BaseCache):
         if similarity_threshold is None:
             raise Exception("similarity_threshold must be provided, passed None")
         self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
         schema = {
             "index": {
                 "name": "litellm_semantic_cache_index",
@@ -322,7 +324,7 @@ class RedisSemanticCache(BaseCache):

         # create an embedding for prompt
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -359,7 +361,7 @@ class RedisSemanticCache(BaseCache):

         # convert to embedding
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -405,6 +407,7 @@ class RedisSemanticCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list

         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -418,12 +421,24 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]
         # create an embedding for prompt
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
         )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )

         # get the embedding
         embedding = embedding_response["data"][0]["embedding"]
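The router branch above only fires when the configured embedding model name matches a model_name entry on the proxy's router, so embeddings pick up that deployment's credentials. A minimal sketch of such a router setup follows; the alias, deployment, and credential values are hypothetical placeholders, not part of this commit.

from litellm import Router

# Hypothetical deployment list: if the semantic cache is constructed with
# embedding_model="azure-embedding-model", async_set_cache/async_get_cache
# will embed via llm_router.aembedding using this deployment, instead of
# calling litellm.aembedding directly.
llm_router = Router(
    model_list=[
        {
            "model_name": "azure-embedding-model",  # alias matched against router_model_names
            "litellm_params": {
                "model": "azure/my-embedding-deployment",          # placeholder deployment
                "api_key": "sk-...",                               # placeholder
                "api_base": "https://example.openai.azure.com/",   # placeholder
            },
        }
    ]
)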
@@ -445,6 +460,7 @@ class RedisSemanticCache(BaseCache):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
         from redisvl.query import VectorQuery
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list

         # query
@@ -454,12 +470,24 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]

-        # convert to embedding
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
         )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )

         # get the embedding
         embedding = embedding_response["data"][0]["embedding"]
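Note that the async get path mirrors async_set_cache exactly: the query embedding comes from the router when the model is registered there and from litellm.aembedding otherwise, so cache writes and lookups always embed with the same self.embedding_model.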
@@ -727,6 +755,7 @@ class Cache:
         s3_aws_session_token: Optional[str] = None,
         s3_config: Optional[Any] = None,
         redis_semantic_cache_use_async=False,
+        redis_semantic_cache_embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         """
@@ -757,6 +786,7 @@ class Cache:
                 password,
                 similarity_threshold=similarity_threshold,
                 use_async=redis_semantic_cache_use_async,
+                embedding_model=redis_semantic_cache_embedding_model,
                 **kwargs,
             )
         elif type == "local":
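End to end, the new parameter threads from Cache down to RedisSemanticCache. A minimal usage sketch, assuming a reachable Redis instance and that "redis-semantic" is the cache type string wired to this branch; the host, port, password, and threshold values are placeholders.

import litellm
from litellm.caching import Cache

# Enable the semantic cache with a caller-chosen embedding model; before this
# commit the model was hardcoded to "text-embedding-ada-002".
litellm.cache = Cache(
    type="redis-semantic",
    host="localhost",                  # placeholder Redis host
    port="6379",                       # placeholder Redis port
    password="",                       # placeholder
    similarity_threshold=0.8,          # required; an Exception is raised if None
    redis_semantic_cache_embedding_model="text-embedding-ada-002",
)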