Mirror of https://github.com/BerriAI/litellm.git
allow setting redis_semantic cache_embedding model
This commit is contained in:
parent
617716752e
commit
3c71eb1e71
1 changed file with 42 additions and 12 deletions
@@ -232,6 +232,7 @@ class RedisSemanticCache(BaseCache):
         redis_url=None,
         similarity_threshold=None,
         use_async=False,
+        embedding_model="text-embedding-ada-002",
         **kwargs,
     ):
         from redisvl.index import SearchIndex
@@ -243,6 +244,7 @@ class RedisSemanticCache(BaseCache):
         if similarity_threshold is None:
             raise Exception("similarity_threshold must be provided, passed None")
         self.similarity_threshold = similarity_threshold
+        self.embedding_model = embedding_model
         schema = {
             "index": {
                 "name": "litellm_semantic_cache_index",
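Note: with the two hunks above, the embedding model used for semantic caching is stored on the instance instead of being hard-coded. A minimal sketch of constructing the cache directly with a non-default model (connection values are placeholders, and the import path assumes the class is exposed from litellm.caching as in this file):

from litellm.caching import RedisSemanticCache

# similarity_threshold is required (see the None-check above);
# embedding_model overrides the "text-embedding-ada-002" default.
semantic_cache = RedisSemanticCache(
    host="localhost",   # placeholder
    port="6379",        # placeholder
    password="",        # placeholder
    similarity_threshold=0.8,
    embedding_model="text-embedding-3-small",  # hypothetical non-default model
)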
@@ -322,7 +324,7 @@ class RedisSemanticCache(BaseCache):
 
         # create an embedding for prompt
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -359,7 +361,7 @@ class RedisSemanticCache(BaseCache):
 
         # convert to embedding
         embedding_response = litellm.embedding(
-            model="text-embedding-ada-002",
+            model=self.embedding_model,
             input=prompt,
             cache={"no-store": True, "no-cache": True},
         )
@@ -405,6 +407,7 @@ class RedisSemanticCache(BaseCache):
 
     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
 
         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -418,12 +421,24 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]
         # create an embedding for prompt
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
-        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
 
         # get the embedding
         embedding = embedding_response["data"][0]["embedding"]
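The dispatch pattern introduced in this hunk, and repeated in async_get_cache below, can be read in isolation: prefer the proxy's Router when the configured embedding model is one of its deployments, otherwise call the provider directly. A hedged sketch (the helper name _aembed_prompt is hypothetical; llm_router and llm_model_list stand in for the proxy globals imported above):

import litellm

async def _aembed_prompt(prompt, embedding_model, llm_router, llm_model_list):
    # collect the model_name of every deployment the proxy router serves
    router_model_names = (
        [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
    )
    if llm_router is not None and embedding_model in router_model_names:
        # model is configured on the proxy -> embed through the router
        response = await llm_router.aembedding(
            model=embedding_model,
            input=prompt,
            # no-store/no-cache: don't let the cache's own lookup get cached
            cache={"no-store": True, "no-cache": True},
        )
    else:
        # fall back to a direct litellm call
        response = await litellm.aembedding(
            model=embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )
    return response["data"][0]["embedding"]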
@@ -445,6 +460,7 @@ class RedisSemanticCache(BaseCache):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
         from redisvl.query import VectorQuery
         import numpy as np
+        from litellm.proxy.proxy_server import llm_router, llm_model_list
 
         # query
 
@@ -454,12 +470,24 @@ class RedisSemanticCache(BaseCache):
         for message in messages:
             prompt += message["content"]
 
-        # convert to embedding
-        embedding_response = await litellm.aembedding(
-            model="text-embedding-ada-002",
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
-        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if llm_router is not None and self.embedding_model in router_model_names:
+            embedding_response = await llm_router.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
+        else:
+            # convert to embedding
+            embedding_response = await litellm.aembedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            )
 
         # get the embedding
         embedding = embedding_response["data"][0]["embedding"]
@ -727,6 +755,7 @@ class Cache:
|
||||||
s3_aws_session_token: Optional[str] = None,
|
s3_aws_session_token: Optional[str] = None,
|
||||||
s3_config: Optional[Any] = None,
|
s3_config: Optional[Any] = None,
|
||||||
redis_semantic_cache_use_async=False,
|
redis_semantic_cache_use_async=False,
|
||||||
|
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@@ -757,6 +786,7 @@ class Cache:
                 password,
                 similarity_threshold=similarity_threshold,
                 use_async=redis_semantic_cache_use_async,
+                embedding_model=redis_semantic_cache_embedding_model,
                 **kwargs,
             )
         elif type == "local":
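For callers, the change surfaces as the new redis_semantic_cache_embedding_model argument on Cache. A usage sketch (connection values are placeholders; when running behind the proxy, the model name can be a model_name from the router config, per the fallback logic above):

import litellm
from litellm import Cache

litellm.cache = Cache(
    type="redis-semantic",
    host="localhost",   # placeholder
    port="6379",        # placeholder
    password="",        # placeholder
    similarity_threshold=0.8,
    redis_semantic_cache_embedding_model="text-embedding-3-small",  # hypothetical override
)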