Memory tests pass now

This commit is contained in:
Ashwin Bharambe 2024-12-15 20:55:06 -08:00
parent e51154964f
commit 59ce047aea
23 changed files with 122 additions and 81 deletions

View file

@ -15,6 +15,9 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import LlamaGuardConfig
@ -258,18 +261,18 @@ class LlamaGuardShield:
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str):
if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
conversation.append(m)
elif isinstance(m.content, ImageMedia):
elif isinstance(m.content, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str):
if isinstance(c, str) or isinstance(c, TextContentItem):
content.append(c)
elif isinstance(c, ImageMedia):
elif isinstance(c, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
@ -292,7 +295,7 @@ class LlamaGuardShield:
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
for m in messages
]
)

View file

@ -17,6 +17,9 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import PromptGuardConfig, PromptGuardType
@ -83,7 +86,7 @@ class PromptGuardShield:
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_text_media_as_str(message.content)
text = interleaved_content_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")