Make embedding generation go through inference (#606)

This PR does the following:
1) Adds the ability to generate embeddings in all supported inference
providers.
2) Moves all the memory providers to use the inference API and improves
the memory tests to set up the inference stack correctly and use the
embedding models.

This is a merge from #589 and #598
This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:47:50 -08:00 committed by GitHub
parent a14785af46
commit 96e158eaac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 677 additions and 156 deletions

View file

@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import Api
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# NOTE(review): presumably the embedding vector width of the
# all-MiniLM-L6-v2 model -- confirm against the model card.
ALL_MINILM_L6_V2_DIMENSION = 384
# Cache of already-loaded embedding models, keyed by model name.
# Populated by get_embedding_model().
EMBEDDING_MODELS = {}
def get_embedding_model(model: str) -> "SentenceTransformer":
    """Return the sentence-transformer for *model*, loading it at most once.

    The module-level ``EMBEDDING_MODELS`` dict is consulted first. On a miss
    the model is constructed with ``SentenceTransformer(model)``, stored in
    the dict under *model*, and then returned.

    Args:
        model: Name passed unchanged to the ``SentenceTransformer``
            constructor; also used as the cache key.

    Returns:
        The cached or freshly loaded ``SentenceTransformer`` instance.
    """
    # The dict is only mutated, never rebound, so no ``global`` is needed.
    cached = EMBEDDING_MODELS.get(model)
    if cached is None:
        log.info(f"Loading sentence transformer for {model}...")

        # Imported here rather than at module top, so the package is only
        # required when a model actually has to be loaded.
        from sentence_transformers import SentenceTransformer

        cached = SentenceTransformer(model)
        EMBEDDING_MODELS[model] = cached
    return cached
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
class BankWithIndex:
bank: VectorMemoryBank
index: EmbeddingIndex
inference_api: Api.inference
async def insert_documents(
self,
documents: List[MemoryBankDocument],
) -> None:
model = get_embedding_model(self.bank.embedding_model)
for doc in documents:
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
@ -183,7 +165,10 @@ class BankWithIndex:
)
if not chunks:
continue
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
embeddings_response = await self.inference_api.embeddings(
self.bank.embedding_model, [x.content for x in chunks]
)
embeddings = np.array(embeddings_response.embeddings)
await self.index.add_chunks(chunks, embeddings)
@ -208,6 +193,8 @@ class BankWithIndex:
else:
query_str = _process(query)
model = get_embedding_model(self.bank.embedding_model)
query_vector = model.encode([query_str])[0].astype(np.float32)
embeddings_response = await self.inference_api.embeddings(
self.bank.embedding_model, [query_str]
)
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query(query_vector, k, score_threshold)