Memory tests pass now

This commit is contained in:
Ashwin Bharambe 2024-12-15 20:55:06 -08:00
parent e51154964f
commit 59ce047aea
23 changed files with 122 additions and 81 deletions

View file

@ -17,6 +17,9 @@ from llama_stack.apis.agents import (
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
@ -42,7 +45,7 @@ async def default_rag_query_generator(
messages: List[Message],
**kwargs,
):
return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
async def llm_rag_query_generator(

View file

@ -114,7 +114,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -218,8 +218,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
yield chunk
async def embeddings(
self, model_id: str, contents: list[InterleavedTextMedia]
self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
raise NotImplementedError()

View file

@ -4,12 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaInlineImplConfig
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
async def get_provider_impl(
config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
):
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config)
impl = ChromaMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -19,9 +19,10 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
@ -208,7 +209,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id)

View file

@ -15,6 +15,9 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import LlamaGuardConfig
@ -258,18 +261,18 @@ class LlamaGuardShield:
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str):
if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
conversation.append(m)
elif isinstance(m.content, ImageMedia):
elif isinstance(m.content, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str):
if isinstance(c, str) or isinstance(c, TextContentItem):
content.append(c)
elif isinstance(c, ImageMedia):
elif isinstance(c, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
@ -292,7 +295,7 @@ class LlamaGuardShield:
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
for m in messages
]
)

View file

@ -17,6 +17,9 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import PromptGuardConfig, PromptGuardType
@ -83,7 +86,7 @@ class PromptGuardShield:
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_text_media_as_str(message.content)
text = interleaved_content_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")