mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 01:17:15 +00:00
Memory tests pass now
This commit is contained in:
parent
e51154964f
commit
59ce047aea
23 changed files with 122 additions and 81 deletions
|
|
@ -17,6 +17,9 @@ from llama_stack.apis.agents import (
|
|||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
async def generate_rag_query(
|
||||
|
|
@ -42,7 +45,7 @@ async def default_rag_query_generator(
|
|||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
|
||||
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
|
@ -218,8 +218,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self, model_id: str, contents: list[InterleavedTextMedia]
|
||||
self, model_id: str, contents: List[InterleavedContent]
|
||||
) -> EmbeddingsResponse:
|
||||
log.info("vLLM embeddings")
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -4,12 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import ChromaInlineImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
|
||||
async def get_provider_impl(
|
||||
config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
|
||||
|
||||
impl = ChromaMemoryAdapter(config)
|
||||
impl = ChromaMemoryAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ from numpy.typing import NDArray
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
|
|
@ -208,7 +209,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
|
@ -258,18 +261,18 @@ class LlamaGuardShield:
|
|||
most_recent_img = None
|
||||
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m.content, str):
|
||||
if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, ImageMedia):
|
||||
elif isinstance(m.content, ImageContentItem):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = m.content
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, list):
|
||||
content = []
|
||||
for c in m.content:
|
||||
if isinstance(c, str):
|
||||
if isinstance(c, str) or isinstance(c, TextContentItem):
|
||||
content.append(c)
|
||||
elif isinstance(c, ImageMedia):
|
||||
elif isinstance(c, ImageContentItem):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = c
|
||||
content.append(c)
|
||||
|
|
@ -292,7 +295,7 @@ class LlamaGuardShield:
|
|||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
[
|
||||
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
|
||||
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
|
||||
for m in messages
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
|
|
@ -83,7 +86,7 @@ class PromptGuardShield:
|
|||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_text_media_as_str(message.content)
|
||||
text = interleaved_content_as_str(message.content)
|
||||
|
||||
# run model on messages and return response
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue