mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-10 03:30:58 +00:00)

commit 6abef716dd: rebase on top of registry
107 changed files with 4813 additions and 3587 deletions

@@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
+        assert request.stream is True, "Non-streaming not supported"
+
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
             raise ValueError(f"Session {request.session_id} not found")

@@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
             raise ValueError(f"Session {session_id} not found")
 
         if session_info.memory_bank_id is None:
-            memory_bank = await self.memory_api.create_memory_bank(
-                name=f"memory_bank_{session_id}",
-                config=VectorMemoryBankConfig(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            bank_id = f"memory_bank_{session_id}"
+            memory_bank = VectorMemoryBankDef(
+                identifier=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
             )
-            bank_id = memory_bank.bank_id
+            await self.memory_api.register_memory_bank(memory_bank)
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
             bank_id = session_info.memory_bank_id

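The memory-bank change above is the registry pattern in miniature: instead of asking the provider to create a bank and hand back a server-generated ID, the caller picks a deterministic identifier, builds a VectorMemoryBankDef, and registers it. A minimal sketch of that flow with a stand-in in-memory registry (only VectorMemoryBankDef's fields and the register_memory_bank name come from the diff; everything else is assumed):

    from dataclasses import dataclass


    @dataclass
    class VectorMemoryBankDef:
        identifier: str
        embedding_model: str
        chunk_size_in_tokens: int


    class MemoryAPI:
        def __init__(self) -> None:
            self.registry: dict = {}

        async def register_memory_bank(self, bank: VectorMemoryBankDef) -> None:
            # The caller chooses the ID up front, so nothing has to round-trip back.
            self.registry[bank.identifier] = bank


    async def ensure_session_bank(memory_api: MemoryAPI, session_id: str) -> str:
        bank_id = f"memory_bank_{session_id}"
        await memory_api.register_memory_bank(
            VectorMemoryBankDef(
                identifier=bank_id,
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=512,
            )
        )
        return bank_id

Because the identifier is derived from the session, the agent can re-derive it later without storing a server-issued UUID.
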
@@ -673,7 +674,7 @@ class ChatAgent(ShieldRunnerMixin):
 
     async def _retrieve_context(
         self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[List[str], List[int]]:  # (rag_context, bank_ids)
+    ) -> Tuple[Optional[List[str]], Optional[List[int]]]:  # (rag_context, bank_ids)
         bank_ids = []
 
         memory = self._memory_tool_definition()

@@ -722,12 +723,13 @@ class ChatAgent(ShieldRunnerMixin):
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]
 
+        if not chunks:
+            return None, bank_ids
+
         # sort by score
         chunks, scores = zip(
             *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
         )
-        if not chunks:
-            return None, bank_ids
 
         tokens = 0
         picked = []

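Reordering the emptiness check is a real bugfix, not a style change: on an empty result set, sorted(zip([], [])) is [], so zip(*[]) produces nothing and the tuple unpacking raises ValueError before the old, post-sort guard could ever fire. A quick demonstration:

    chunks = []
    scores = []

    try:
        # the old order, applied to an empty result set
        chunks, scores = zip(
            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        )
    except ValueError as e:
        print(e)  # not enough values to unpack (expected 2, got 0)

    # the new order returns early instead
    if not chunks:
        print("no retrieved context; return (None, bank_ids)")
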
@@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )
 
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,

@@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        agent = await self.get_agent(agent_id)
-
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
             attachments=attachments,
-            stream=stream,
+            stream=True,
         )
+        if stream:
+            return self._create_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _create_agent_turn_streaming(
+        self,
+        request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
         async for event in agent.create_and_execute_turn(request):
             yield event

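Note that create_agent_turn is now a plain def on an async API: calling it builds the request and returns the async generator from _create_agent_turn_streaming without entering it, so the caller decides whether to iterate (streaming) or, eventually, await a non-streaming variant. A minimal sketch of that dispatch shape, independent of llama-stack's types:

    import asyncio
    from typing import AsyncGenerator


    class TurnService:
        def create_turn(self, message: str, stream: bool) -> AsyncGenerator:
            # plain def: no coroutine starts here; we only pick what to hand back
            if stream:
                return self._stream_turn(message)
            raise NotImplementedError("Non-streaming turns not yet implemented")

        async def _stream_turn(self, message: str) -> AsyncGenerator:
            for chunk in message.split():
                await asyncio.sleep(0)  # stand-in for real inference work
                yield chunk


    async def main() -> None:
        async for chunk in TurnService().create_turn("hello streaming world", stream=True):
            print(chunk)


    asyncio.run(main())
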
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import CodeShieldConfig
+
+
+async def get_provider_impl(config: CodeShieldConfig, deps):
+    from .code_scanner import MetaReferenceCodeScannerSafetyImpl
+
+    impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
+    await impl.initialize()
+    return impl

@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+
+from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
+from termcolor import cprint
+
+from .config import CodeScannerConfig
+
+from llama_stack.apis.safety import *  # noqa: F403
+
+
+class MetaReferenceCodeScannerSafetyImpl(Safety):
+    def __init__(self, config: CodeScannerConfig, deps) -> None:
+        self.config = config
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type != ShieldType.code_scanner.value:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+
+    async def run_shield(
+        self,
+        shield_type: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
+
+        from codeshield.cs import CodeShield
+
+        text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
+        cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
+        result = await CodeShield.scan_code(text)
+
+        violation = None
+        if result.is_insecure:
+            violation = SafetyViolation(
+                violation_level=(ViolationLevel.ERROR),
+                user_message="Sorry, I found security concerns in the code.",
+                metadata={
+                    "violation_type": ",".join(
+                        [issue.pattern_id for issue in result.issues_found]
+                    )
+                },
+            )
+        return RunShieldResponse(violation=violation)

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class CodeShieldConfig(BaseModel):
+    pass

@@ -43,13 +43,12 @@ class MetaReferenceEvalsImpl(Evals):
         print("generation start")
         for msg in x1[:5]:
             print("generation for msg: ", msg)
-            response = self.inference_api.chat_completion(
+            response = await self.inference_api.chat_completion(
                 model=model,
                 messages=[msg],
                 stream=False,
             )
-            async for x in response:
-                generation_outputs.append(x.completion_message.content)
+            generation_outputs.append(response.completion_message.content)
 
         x2 = task_impl.postprocess(generation_outputs)
         eval_results = task_impl.score(x2)

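The evals fix corrects a misuse of the inference API: with stream=False, chat_completion resolves to a single ChatCompletionResponse, so the result must be awaited once rather than iterated with async for. A self-contained sketch of the corrected calling convention (the stub API below is illustrative, not llama-stack's):

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class CompletionMessage:
        content: str


    @dataclass
    class ChatCompletionResponse:
        completion_message: CompletionMessage


    class InferenceAPI:
        async def chat_completion(self, model: str, messages: list, stream: bool = False):
            assert stream is False  # non-streaming path: one awaitable, one response
            return ChatCompletionResponse(CompletionMessage(f"echo: {messages[-1]}"))


    async def main() -> None:
        api = InferenceAPI()
        response = await api.chat_completion(model="demo", messages=["hi"], stream=False)
        generation_outputs = [response.completion_message.content]
        print(generation_outputs)


    asyncio.run(main())
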
@@ -297,7 +297,7 @@ class Llama:
                     token=next_token[0].item(),
                     text=self.tokenizer.decode(next_token.tolist()),
                     logprobs=(
-                        token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
+                        token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
                         if logprobs
                         else None
                     ),

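Assuming the usual generate-loop structure (where prev_pos trails cur_pos by exactly one during incremental decoding), the slice change only matters on the prefill step: there, prev_pos is 0 and cur_pos is the prompt length, so the old slice returned a logprob per prompt position instead of one value for the newly sampled token. A toy illustration with plain lists standing in for the token_logprobs tensor:

    token_logprobs = [[-0.1, -0.5, -0.9, -1.3]]  # batch of 1, one column per position

    prev_pos, cur_pos = 0, 2  # prefill step over a 2-token prompt

    old_slice = token_logprobs[0][prev_pos + 1 : cur_pos + 1]  # [-0.5, -0.9]: too wide
    new_slice = token_logprobs[0][cur_pos : cur_pos + 1]       # [-0.9]: just the new token

    print(old_slice, new_slice)
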
@@ -6,15 +6,14 @@
 
 import asyncio
 
-from typing import AsyncIterator, List, Union
+from typing import AsyncGenerator, List
 
 from llama_models.sku_list import resolve_model
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
-from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
 )
 
 from .config import MetaReferenceImplConfig

@@ -25,7 +24,7 @@ from .model_parallel import LlamaModelParallelGenerator
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-class MetaReferenceInferenceImpl(Inference, RoutableProvider):
+class MetaReferenceInferenceImpl(Inference):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)

@@ -35,21 +34,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
         # verify that the checkpoint actually is for this model lol
 
     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        assert (
-            len(routing_keys) == 1
-        ), f"Only one routing key is supported {routing_keys}"
-        assert routing_keys[0] == self.config.model
+    async def register_model(self, model: ModelDef) -> None:
+        if model.identifier != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
+            )
 
     async def shutdown(self) -> None:
         self.generator.stop()
 
-    # hm, when stream=False, we should not be doing SSE :/ which is what the
-    # top-level server is going to do. make the typing more specific here
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -59,9 +57,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
+    ) -> AsyncGenerator:
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,

@@ -74,7 +73,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
             logprobs=logprobs,
         )
 
-        messages = augment_messages_for_tools(request)
         model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(

@@ -88,21 +86,74 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
         if SEMAPHORE.locked():
             raise RuntimeError("Only one concurrent request is supported")
 
+        if request.stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         async with SEMAPHORE:
-            if request.stream:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
-                    )
-                )
+            messages = chat_completion_request_to_messages(request)
 
             tokens = []
             logprobs = []
 
             stop_reason = None
 
-            buffer = ""
             for token_result in self.generator.chat_completion(
                 messages=messages,
                 temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta="",
+                )
+            )
+
+            tokens = []
+            logprobs = []
+            stop_reason = None
+            ipython = False
+
+            for token_result in self.generator.chat_completion(

@@ -113,10 +164,9 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                 logprobs=request.logprobs,
                 tool_prompt_format=request.tool_prompt_format,
             ):
-                buffer += token_result.text
                 tokens.append(token_result.token)
 
-                if not ipython and buffer.startswith("<|python_tag|>"):
+                if not ipython and token_result.text.startswith("<|python_tag|>"):
                     ipython = True
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(

@@ -127,13 +177,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                         ),
                     )
                 )
-                buffer = buffer[len("<|python_tag|>") :]
                 continue
 
-            if not request.stream:
-                if request.logprobs:
-                    logprobs.append(token_result.logprob)
-
-                continue
-
             if token_result.text == "<|eot_id|>":

@@ -154,59 +197,61 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                     delta = text
 
                 if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs.append(
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        )
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=delta,
                             stop_reason=stop_reason,
+                            logprobs=logprobs if request.logprobs else None,
                         )
                     )
 
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
             # TODO(ashwin): parse tool calls separately here and report errors?
             # if someone breaks the iteration before coming here we are toast
             message = self.generator.formatter.decode_assistant_message(
                 tokens, stop_reason
             )
-            if request.stream:
-                parsed_tool_calls = len(message.tool_calls) > 0
-                if ipython and not parsed_tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.failure,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
 
-                for tool_call in message.tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=tool_call,
-                                parse_status=ToolCallParseStatus.success,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
+            parsed_tool_calls = len(message.tool_calls) > 0
+            if ipython and not parsed_tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.failure,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
 
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
-                        stop_reason=stop_reason,
-                    )
-                )
-
-                # TODO(ashwin): what else do we need to send out here when everything finishes?
-            else:
-                yield ChatCompletionResponse(
-                    completion_message=message,
-                    logprobs=logprobs if request.logprobs else None,
-                )
+            for tool_call in message.tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=tool_call,
+                            parse_status=ToolCallParseStatus.success,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta="",
+                    stop_reason=stop_reason,
+                )
+            )

@@ -13,15 +13,15 @@ from typing import Optional
 import torch
 
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.llama3.api.model import Transformer, TransformerBlock
 
+from llama_models.datatypes import CheckpointQuantizationFormat
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from termcolor import cprint
 from torch import Tensor
 
 from llama_stack.apis.inference import QuantizationType
 
-from llama_stack.apis.inference.config import (
-    CheckpointQuantizationFormat,
+from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceImplConfig,
 )
 

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import logging
-import uuid
 
 from typing import Any, Dict, List, Optional
 

@@ -14,7 +13,6 @@ import numpy as np
 from numpy.typing import NDArray
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.utils.memory.vector_store import (

@@ -63,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(Memory, RoutableProvider):
+class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}

@@ -72,37 +70,18 @@ class FaissMemoryImpl(Memory, RoutableProvider):
 
     async def shutdown(self) -> None: ...
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        assert url is None, "URL is not supported for this implementation"
+        memory_bank: MemoryBankDef,
+    ) -> None:
         assert (
-            config.type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {config.type}"
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
 
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
+        index = BankWithIndex(
+            bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
         )
-        index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
-        self.cache[bank_id] = index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        index = self.cache.get(bank_id)
-        if index is None:
-            return None
-        return index.bank
+        self.cache[memory_bank.identifier] = index
 
     async def insert_documents(
         self,

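On the provider side, the registry flips creation into validation: register_memory_bank (like register_model and register_shield elsewhere in this commit) receives a definition the distribution layer already holds, and the provider either indexes it or rejects it. A condensed sketch of that contract, with stand-in types:

    from dataclasses import dataclass, field


    @dataclass
    class MemoryBankDef:
        identifier: str
        type: str = "vector"


    @dataclass
    class FaissLikeProvider:
        cache: dict = field(default_factory=dict)

        async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
            # validate rather than create: the definition arrives fully formed
            if memory_bank.type != "vector":
                raise ValueError(f"Only vector banks are supported {memory_bank.type}")
            self.cache[memory_bank.identifier] = object()  # stand-in for BankWithIndex
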
@@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
     return interleaved_text_media_as_str(message.content)
 
 
-# For shields that operate on simple strings
 class TextShield(ShieldBase):
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return "\n".join([message_content_as_str(m) for m in messages])

@@ -56,9 +55,3 @@ class TextShield(ShieldBase):
     @abstractmethod
     async def run_impl(self, text: str) -> ShieldResponse:
         raise NotImplementedError()
-
-
-class DummyShield(TextShield):
-    async def run_impl(self, text: str) -> ShieldResponse:
-        # Dummy return LOW to test e2e
-        return ShieldResponse(is_violation=False)

@@ -9,23 +9,19 @@ from typing import List, Optional
 
 from llama_models.sku_list import CoreModelId, safety_models
 
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, field_validator
 
 
-class MetaReferenceShieldType(Enum):
-    llama_guard = "llama_guard"
-    code_scanner_guard = "code_scanner_guard"
-    injection_shield = "injection_shield"
-    jailbreak_shield = "jailbreak_shield"
+class PromptGuardType(Enum):
+    injection = "injection"
+    jailbreak = "jailbreak"
 
 
 class LlamaGuardShieldConfig(BaseModel):
     model: str = "Llama-Guard-3-1B"
     excluded_categories: List[str] = []
-    disable_input_check: bool = False
-    disable_output_check: bool = False
 
-    @validator("model")
+    @field_validator("model")
+    @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = [

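The @validator to @field_validator swap tracks the pydantic v2 API, where field validators are declared with @field_validator plus @classmethod. A standalone example of the v2 form used in this config (the permitted list here is illustrative):

    from pydantic import BaseModel, field_validator


    class ShieldModelConfig(BaseModel):
        model: str = "Llama-Guard-3-1B"

        @field_validator("model")
        @classmethod
        def validate_model(cls, model: str) -> str:
            permitted_models = ["Llama-Guard-3-1B", "Llama-Guard-3-8B"]
            if model not in permitted_models:
                raise ValueError(f"Invalid model: {model}; permitted: {permitted_models}")
            return model


    print(ShieldModelConfig(model="Llama-Guard-3-1B").model)
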
@@ -47,10 +43,6 @@ class LlamaGuardShieldConfig(BaseModel):
         return model
 
 
-class PromptGuardShieldConfig(BaseModel):
-    model: str = "Prompt-Guard-86M"
-
-
 class SafetyConfig(BaseModel):
     llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
-    prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
+    enable_prompt_guard: Optional[bool] = False

|
|||
|
|
@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
|
@@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
         self.model = model
         self.inference_api = inference_api
         self.excluded_categories = excluded_categories
-        self.disable_input_check = disable_input_check
-        self.disable_output_check = disable_output_check
 
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)

@@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
         messages = self.validate_messages(messages)
-        if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
-            return ShieldResponse(
-                is_violation=False,
-            )
 
         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
             shield_input_message = self.build_vision_shield_input(messages)

@@ -6,56 +6,43 @@
 
 from typing import Any, Dict, List
 
 from llama_models.sku_list import resolve_model
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import Api, RoutableProvider
+from llama_stack.distribution.datatypes import Api
 
-from llama_stack.providers.impls.meta_reference.safety.shields.base import (
-    OnViolationAction,
-)
-
-from .config import MetaReferenceShieldType, SafetyConfig
-
-from .shields import (
-    CodeScannerShield,
-    InjectionShield,
-    JailbreakShield,
-    LlamaGuardShield,
-    PromptGuardShield,
-    ShieldBase,
-)
+from .base import OnViolationAction, ShieldBase
+from .config import SafetyConfig
+from .llama_guard import LlamaGuardShield
+from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
 
 
-def resolve_and_get_path(model_name: str) -> str:
-    model = resolve_model(model_name)
-    assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model.descriptor())
-    return model_dir
+PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
 
-class MetaReferenceSafetyImpl(Safety, RoutableProvider):
+class MetaReferenceSafetyImpl(Safety):
     def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
 
+        self.available_shields = []
+        if config.llama_guard_shield:
+            self.available_shields.append(ShieldType.llama_guard.value)
+        if config.enable_prompt_guard:
+            self.available_shields.append(ShieldType.prompt_guard.value)
+
     async def initialize(self) -> None:
-        shield_cfg = self.config.prompt_guard_shield
-        if shield_cfg is not None:
-            model_dir = resolve_and_get_path(shield_cfg.model)
+        if self.config.enable_prompt_guard:
+            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
             _ = PromptGuardShield.instance(model_dir)
 
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        available_shields = [v.value for v in MetaReferenceShieldType]
-        for key in routing_keys:
-            if key not in available_shields:
-                raise ValueError(f"Unknown safety shield type: {key}")
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type not in self.available_shields:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
 
     async def run_shield(
         self,

@@ -63,10 +50,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        available_shields = [v.value for v in MetaReferenceShieldType]
-        assert shield_type in available_shields, f"Unknown shield {shield_type}"
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
 
-        shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
+        shield = self.get_shield_impl(shield_def)
 
         messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message

@ -92,34 +80,22 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
cfg = cfg.llama_guard_shield
|
||||
assert (
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
|
||||
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
||||
if shield.type == ShieldType.llama_guard.value:
|
||||
cfg = self.config.llama_guard_shield
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return JailbreakShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.injection_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use PromptGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.code_scanner_guard:
|
||||
return CodeScannerShield.instance()
|
||||
elif shield.type == ShieldType.prompt_guard.value:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||
if subtype == "injection":
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif subtype == "jailbreak":
|
||||
return JailbreakShield.instance(model_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {typ}")
|
||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
||||
|
|
|
|||
|
|
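With the shield-type enum gone, the prompt-guard variants are distinguished by a params entry on the registered ShieldDef rather than by separate types. A sketch of what the two registrations might look like from the data side (plain dicts as stand-ins for ShieldDef; the params key comes from the diff, the identifiers are made up):

    shields = [
        {
            "identifier": "prompt_guard_injection",
            "type": "prompt_guard",
            "params": {"prompt_guard_type": "injection"},
        },
        {
            "identifier": "prompt_guard_jailbreak",
            "type": "prompt_guard",
            "params": {"prompt_guard_type": "jailbreak"},
        },
    ]

    for shield in shields:
        subtype = shield["params"].get("prompt_guard_type", "injection")
        if subtype not in ("injection", "jailbreak"):
            raise ValueError(f"Unknown prompt guard type: {subtype}")
        print(shield["identifier"], "->", subtype)
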
@@ -1,33 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-# supress warnings and spew of logs from hugging face
-import transformers
-
-from .base import (  # noqa: F401
-    DummyShield,
-    OnViolationAction,
-    ShieldBase,
-    ShieldResponse,
-    TextShield,
-)
-from .code_scanner import CodeScannerShield  # noqa: F401
-from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import (  # noqa: F401
-    InjectionShield,
-    JailbreakShield,
-    PromptGuardShield,
-)
-
-transformers.logging.set_verbosity_error()
-
-import os
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-import warnings
-
-warnings.filterwarnings("ignore")

@@ -1,27 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from termcolor import cprint
-
-from .base import ShieldResponse, TextShield
-
-
-class CodeScannerShield(TextShield):
-    async def run_impl(self, text: str) -> ShieldResponse:
-        from codeshield.cs import CodeShield
-
-        cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
-        result = await CodeShield.scan_code(text)
-        if result.is_insecure:
-            return ShieldResponse(
-                is_violation=True,
-                violation_type=",".join(
-                    [issue.pattern_id for issue in result.issues_found]
-                ),
-                violation_return_message="Sorry, I found security concerns in the code.",
-            )
-        else:
-            return ShieldResponse(is_violation=False)