Merge branch 'main' into add-nvidia-inference-adapter

This commit is contained in:
Matthew Farrellee 2024-11-15 14:09:12 -05:00
commit 43262df033
399 changed files with 17826 additions and 10490 deletions

View file

@ -1,187 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
"Llama3.2-11B-Vision-Instruct": "llama-v3p2-11b-vision-instruct",
"Llama3.2-90B-Vision-Instruct": "llama-v3p2-90b-vision-instruct",
}
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
client = Fireworks(api_key=self.config.api_key)
if stream:
return self._stream_completion(request, client)
else:
return await self._nonstream_completion(request, client)
async def _nonstream_completion(
self, request: CompletionRequest, client: Fireworks
) -> CompletionResponse:
params = self._get_params(request)
r = await client.completion.acreate(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(
self, request: CompletionRequest, client: Fireworks
) -> AsyncGenerator:
params = self._get_params(request)
stream = client.completion.acreate(**params)
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
client = Fireworks(api_key=self.config.api_key)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await client.completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> AsyncGenerator:
params = self._get_params(request)
stream = client.completion.acreate(**params)
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
yield chunk
def _get_params(self, request) -> dict:
prompt = ""
if type(request) == ChatCompletionRequest:
prompt = chat_completion_request_to_prompt(request, self.formatter)
elif type(request) == CompletionRequest:
prompt = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")
# Fireworks always prepends with BOS
if prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
options = get_sampling_options(request.sampling_params)
options.setdefault("max_tokens", 512)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,
"stream": request.stream,
**options,
}
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
class BedrockSafetyConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request."""
aws_profile: str = Field(
default="default",
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
)

View file

@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
class TogetherProviderDataValidator(BaseModel):
together_api_key: str
@json_schema_type
class TogetherSafetyConfig(BaseModel):
url: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: Optional[str] = Field(
default=None,
description="The Together AI API Key (default for the distribution, if any)",
)

View file

@ -1,101 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import TogetherSafetyConfig
TOGETHER_SHIELD_MODEL_MAP = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivate):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
raise ValueError("Registering dynamic shields is not supported")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=ShieldType.llama_guard.value,
type=ShieldType.llama_guard.value,
params={},
)
]
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
model = shield_def.params.get("model", "llama_guard")
if model not in TOGETHER_SHIELD_MODEL_MAP:
raise ValueError(f"Unsupported safety model: {model}")
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
# messages can have role assistant or user
api_messages = []
for message in messages:
if message.role in (Role.user.value, Role.assistant.value):
api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response(
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
)
return RunShieldResponse(violation=violation)
async def get_safety_response(
api_key: str, model_name: str, messages: List[Dict[str, str]]
) -> Optional[SafetyViolation]:
client = Together(api_key=api_key)
response = client.chat.completions.create(messages=messages, model=model_name)
if len(response.choices) == 0:
return None
response_text = response.choices[0].message.content
if response_text == "safe":
return None
parts = response_text.split("\n")
if len(parts) != 2:
return None
if parts[0] == "unsafe":
return SafetyViolation(
violation_level=ViolationLevel.ERROR,
metadata={"violation_type": parts[1]},
)
return None

View file

@ -6,15 +6,17 @@
from enum import Enum
from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.shields import ShieldDef
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
@json_schema_type
@ -34,39 +36,42 @@ class Api(Enum):
memory_banks = "memory_banks"
datasets = "datasets"
scoring_functions = "scoring_functions"
eval_tasks = "eval_tasks"
# built-in API
inspect = "inspect"
class ModelsProtocolPrivate(Protocol):
async def list_models(self) -> List[ModelDef]: ...
async def register_model(self, model: Model) -> None: ...
async def register_model(self, model: ModelDef) -> None: ...
async def unregister_model(self, model_id: str) -> None: ...
class ShieldsProtocolPrivate(Protocol):
async def list_shields(self) -> List[ShieldDef]: ...
async def register_shield(self, shield: ShieldDef) -> None: ...
async def register_shield(self, shield: Shield) -> None: ...
class MemoryBanksProtocolPrivate(Protocol):
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
async def list_memory_banks(self) -> List[MemoryBank]: ...
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol):
async def list_datasets(self) -> List[DatasetDef]: ...
async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
async def register_dataset(self, dataset: Dataset) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
async def list_scoring_functions(self) -> List[ScoringFn]: ...
async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
class EvalTasksProtocolPrivate(Protocol):
async def register_eval_task(self, eval_task: EvalTask) -> None: ...
@json_schema_type
@ -81,6 +86,14 @@ class ProviderSpec(BaseModel):
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
deprecation_warning: Optional[str] = Field(
default=None,
description="If this provider is deprecated, specify the warning message here",
)
deprecation_error: Optional[str] = Field(
default=None,
description="If this provider is deprecated and does NOT work, specify the error message here",
)
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)
@ -90,6 +103,7 @@ class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@ -145,21 +159,27 @@ Fully-qualified name of the module to import. The module is expected to have:
class RemoteProviderConfig(BaseModel):
host: str = "localhost"
port: int
port: Optional[int] = None
protocol: str = "http"
@property
def url(self) -> str:
return f"http://{self.host}:{self.port}"
if self.port is None:
return f"{self.protocol}://{self.host}"
return f"{self.protocol}://{self.host}:{self.port}"
@classmethod
def from_url(cls, url: str) -> "RemoteProviderConfig":
parsed = urlparse(url)
return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
adapter: AdapterSpec = Field(
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
API responses, specify the adapter here.
""",
)
@ -169,38 +189,21 @@ as being "Llama Stack compatible"
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return "llama_stack.distribution.client"
return self.adapter.module
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> Optional[str]:
if self.adapter:
return self.adapter.provider_data_validator
return None
return self.adapter.provider_data_validator
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
adapter=adapter,
)

View file

@ -1,120 +0,0 @@
# LocalInference
LocalInference provides a local inference implementation powered by [executorch](https://github.com/pytorch/executorch/).
Llama Stack currently supports on-device inference for iOS with Android coming soon. You can run on-device inference on Android today using [executorch](https://github.com/pytorch/executorch/tree/main/examples/demo-apps/android/LlamaDemo), PyTorchs on-device inference library.
## Installation
We're working on making LocalInference easier to set up. For now, you'll need to import it via `.xcframework`:
1. Clone the executorch submodule in this repo and its dependencies: `git submodule update --init --recursive`
1. Install [Cmake](https://cmake.org/) for the executorch build`
1. Drag `LocalInference.xcodeproj` into your project
1. Add `LocalInference` as a framework in your app target
1. Add a package dependency on https://github.com/pytorch/executorch (branch latest)
1. Add all the kernels / backends from executorch (but not exectuorch itself!) as frameworks in your app target:
- backend_coreml
- backend_mps
- backend_xnnpack
- kernels_custom
- kernels_optimized
- kernels_portable
- kernels_quantized
1. In "Build Settings" > "Other Linker Flags" > "Any iOS Simulator SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
1. In "Build Settings" > "Other Linker Flags" > "Any iOS SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
## Preparing a model
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md#step-2-prepare-model)
2. Bundle the `.pte` and `tokenizer.model` file into your app
We now support models quantized using SpinQuant and QAT-LoRA which offer a significant performance boost (demo app on iPhone 13 Pro):
| Llama 3.2 1B | Tokens / Second (total) | | Time-to-First-Token (sec) | |
| :---- | :---- | :---- | :---- | :---- |
| | Haiku | Paragraph | Haiku | Paragraph |
| BF16 | 2.2 | 2.5 | 2.3 | 1.9 |
| QAT+LoRA | 7.1 | 3.3 | 0.37 | 0.24 |
| SpinQuant | 10.1 | 5.2 | 0.2 | 0.2 |
## Using LocalInference
1. Instantiate LocalInference with a DispatchQueue. Optionally, pass it into your agents service:
```swift
init () {
runnerQueue = DispatchQueue(label: "org.meta.llamastack")
inferenceService = LocalInferenceService(queue: runnerQueue)
agentsService = LocalAgentsService(inference: inferenceService)
}
```
2. Before making any inference calls, load your model from your bundle:
```swift
let mainBundle = Bundle.main
inferenceService.loadModel(
modelPath: mainBundle.url(forResource: "llama32_1b_spinquant", withExtension: "pte"),
tokenizerPath: mainBundle.url(forResource: "tokenizer", withExtension: "model"),
completion: {_ in } // use to handle load failures
)
```
3. Make inference calls (or agents calls) as you normally would with LlamaStack:
```
for await chunk in try await agentsService.initAndCreateTurn(
messages: [
.UserMessage(Components.Schemas.UserMessage(
content: .case1("Call functions as needed to handle any actions in the following text:\n\n" + text),
role: .user))
]
) {
```
## Troubleshooting
If you receive errors like "missing package product" or "invalid checksum", try cleaning the build folder and resetting the Swift package cache:
(Opt+Click) Product > Clean Build Folder Immediately
```
rm -rf \
~/Library/org.swift.swiftpm \
~/Library/Caches/org.swift.swiftpm \
~/Library/Caches/com.apple.dt.Xcode \
~/Library/Developer/Xcode/DerivedData
```

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, deps):
from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -1,57 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from pydantic import BaseModel
from llama_stack.apis.safety import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
# TODO: clean this up; just remove this type completely
class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
# TODO: this is a caller / agent concern
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
self.on_violation_action = on_violation_action
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
def message_content_as_str(message: Message) -> str:
return interleaved_text_media_as_str(message.content)
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
async def run(self, messages: List[Message]) -> ShieldResponse:
text = self.convert_messages_to_text(messages)
return await self.run_impl(text)
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()

View file

@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Optional
from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, field_validator
class PromptGuardType(Enum):
injection = "injection"
jailbreak = "jailbreak"
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = []
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in safety_models()
if (
m.core_model_id
in {
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_3_11b_vision,
}
)
]
if model not in permitted_models:
raise ValueError(
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
enable_prompt_guard: Optional[bool] = False

View file

@ -1,145 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import auto, Enum
from typing import List
import torch
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()
_instances = {}
_model_cache = None
@staticmethod
def instance(
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action=OnViolationAction.RAISE,
) -> "PromptGuardShield":
action_value = on_violation_action.value
key = (model_dir, threshold, temperature, mode, action_value)
if key not in PromptGuardShield._instances:
PromptGuardShield._instances[key] = PromptGuardShield(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=mode,
on_violation_action=on_violation_action,
)
return PromptGuardShield._instances[key]
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
PromptGuardShield._model_cache = (tokenizer, model)
self.tokenizer, self.model = PromptGuardShield._model_cache
self.temperature = temperature
self.threshold = threshold
self.mode = mode
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
async def run_impl(self, text: str) -> ShieldResponse:
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
if self.mode == self.Mode.INJECTION and (
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
is_violation=False,
)
class JailbreakShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.JAILBREAK,
on_violation_action=on_violation_action,
)
class InjectionShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.INJECTION,
on_violation_action=on_violation_action,
)

View file

@ -1,112 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .base import OnViolationAction, ShieldBase
from .config import SafetyConfig
from .llama_guard import LlamaGuardShield
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
self.available_shields = []
if config.llama_guard_shield:
self.available_shields.append(ShieldType.llama_guard.value)
if config.enable_prompt_guard:
self.available_shields.append(ShieldType.prompt_guard.value)
async def initialize(self) -> None:
if self.config.enable_prompt_guard:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
_ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
raise ValueError("Registering dynamic shields is not supported")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=shield_type,
type=shield_type,
params={},
)
for shield_type in self.available_shields
]
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
shield = self.get_shield_impl(shield_def)
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
violation = SafetyViolation(
violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
return RunShieldResponse(violation=violation)
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
if shield.type == ShieldType.llama_guard.value:
cfg = self.config.llama_guard_shield
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
)
elif shield.type == ShieldType.prompt_guard.value:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
subtype = shield.params.get("prompt_guard_type", "injection")
if subtype == "injection":
return InjectionShield.instance(model_dir)
elif subtype == "jailbreak":
return JailbreakShield.instance(model_dir)
else:
raise ValueError(f"Unknown prompt guard type: {subtype}")
else:
raise ValueError(f"Unknown shield type: {shield.type}")

View file

@ -156,7 +156,7 @@ class ChatAgent(ShieldRunnerMixin):
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if len(turns) == 0 and self.agent_config.instructions != "":
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin):
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
memory_bank = VectorMemoryBankDef(
identifier=bank_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
await self.memory_banks_api.register_memory_bank(memory_bank)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id

View file

@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())

View file

@ -80,5 +80,5 @@ class AgentPersistence:
except Exception as e:
print(f"Error parsing turn: {e}")
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns

View file

@ -32,18 +32,18 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(
self, messages: List[Message], shield_types: List[str]
self, messages: List[Message], identifiers: List[str]
) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
shield_type=shield_type,
shield_id=identifier,
messages=messages,
)
for shield_type in shield_types
for identifier in identifiers
]
)
for shield_type, response in zip(shield_types, responses):
for identifier, response in zip(identifiers, responses):
if not response.violation:
continue
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{shield_type} raised a warning",
f"[Warn]{identifier} raised a warning",
color="red",
)

View file

@ -80,7 +80,7 @@ class MockInferenceAPI:
class MockSafetyAPI:
async def run_shield(
self, shield_type: str, messages: List[Message]
self, shield_id: str, messages: List[Message]
) -> RunShieldResponse:
return RunShieldResponse(violation=None)

View file

@ -9,8 +9,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool

View file

@ -4,15 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceDatasetIOConfig
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: MetaReferenceDatasetIOConfig,
config: LocalFSDatasetIOConfig,
_deps,
):
from .datasetio import MetaReferenceDatasetIOImpl
from .datasetio import LocalFSDatasetIOImpl
impl = MetaReferenceDatasetIOImpl(config)
impl = LocalFSDatasetIOImpl(config)
await impl.initialize()
return impl

View file

@ -6,4 +6,4 @@
from llama_stack.apis.datasetio import * # noqa: F401, F403
class MetaReferenceDatasetIOConfig(BaseModel): ...
class LocalFSDatasetIOConfig(BaseModel): ...

View file

@ -3,22 +3,19 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
from typing import List, Optional
from typing import Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import unquote
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from .config import MetaReferenceDatasetIOConfig
from .config import LocalFSDatasetIOConfig
class BaseDataset(ABC):
@ -40,12 +37,12 @@ class BaseDataset(ABC):
@dataclass
class DatasetInfo:
dataset_def: DatasetDef
dataset_def: Dataset
dataset_impl: BaseDataset
class PandasDataframeDataset(BaseDataset):
def __init__(self, dataset_def: DatasetDef, *args, **kwargs) -> None:
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dataset_def = dataset_def
self.df = None
@ -73,37 +70,15 @@ class PandasDataframeDataset(BaseDataset):
if self.df is not None:
return
# TODO: more robust support w/ data url
if self.dataset_def.url.uri.endswith(".csv"):
df = pandas.read_csv(self.dataset_def.url.uri)
elif self.dataset_def.url.uri.endswith(".xlsx"):
df = pandas.read_excel(self.dataset_def.url.uri)
elif self.dataset_def.url.uri.startswith("data:"):
parts = parse_data_url(self.dataset_def.url.uri)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
else:
data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
mime_type = parts["mimetype"]
mime_category = mime_type.split("/")[0]
data_bytes = io.BytesIO(data)
if mime_category == "text":
df = pandas.read_csv(data_bytes)
else:
df = pandas.read_excel(data_bytes)
else:
raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
df = get_dataframe_from_url(self.dataset_def.url)
if df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
self.df = self._validate_dataset_schema(df)
class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
def __init__(self, config: MetaReferenceDatasetIOConfig) -> None:
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
def __init__(self, config: LocalFSDatasetIOConfig) -> None:
self.config = config
# local registry for keeping track of datasets within the provider
self.dataset_infos = {}
@ -114,17 +89,14 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def register_dataset(
self,
dataset_def: DatasetDef,
dataset: Dataset,
) -> None:
dataset_impl = PandasDataframeDataset(dataset_def)
self.dataset_infos[dataset_def.identifier] = DatasetInfo(
dataset_def=dataset_def,
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
async def list_datasets(self) -> List[DatasetDef]:
return [i.dataset_def for i in self.dataset_infos.values()]
async def get_rows_paginated(
self,
dataset_id: str,

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class MetaReferenceEvalConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
) # Uses SQLite config specific to Meta Reference Eval storage

View file

@ -6,16 +6,22 @@
from enum import Enum
from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from tqdm import tqdm
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"
class ColumnName(Enum):
input_query = "input_query"
@ -25,7 +31,7 @@ class ColumnName(Enum):
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval):
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
def __init__(
self,
config: MetaReferenceEvalConfig,
@ -43,12 +49,32 @@ class MetaReferenceEvalImpl(Eval):
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
async def initialize(self) -> None: ...
self.eval_tasks = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing eval_tasks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_eval_tasks = await self.kvstore.range(start_key, end_key)
for eval_task in stored_eval_tasks:
eval_task = EvalTask.model_validate_json(eval_task)
self.eval_tasks[eval_task.identifier] = eval_task
async def shutdown(self) -> None: ...
async def register_eval_task(self, task_def: EvalTask) -> None:
# Store in kvstore
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
await self.kvstore.set(
key=key,
value=task_def.json(),
)
self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
@ -70,21 +96,28 @@ class MetaReferenceEvalImpl(Eval):
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def evaluate_batch(
async def run_eval(
self,
dataset_id: str,
candidate: EvalCandidate,
scoring_functions: List[str],
task_id: str,
task_config: EvalTaskConfig,
) -> Job:
task_def = self.eval_tasks[task_id]
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
rows_in_page=(
-1 if task_config.num_examples is None else task_config.num_examples
),
)
res = await self.evaluate(
res = await self.evaluate_rows(
task_id=task_id,
input_rows=all_rows.rows,
candidate=candidate,
scoring_functions=scoring_functions,
task_config=task_config,
)
# TODO: currently needs to wait for generation before returning
@ -93,12 +126,14 @@ class MetaReferenceEvalImpl(Eval):
self.jobs[job_id] = res
return Job(job_id=job_id)
async def evaluate(
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
candidate: EvalCandidate,
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
@ -108,7 +143,7 @@ class MetaReferenceEvalImpl(Eval):
), "SamplingParams.max_tokens must be provided"
generations = []
for x in input_rows:
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
input_content = eval(str(x[ColumnName.completion_input.value]))
response = await self.inference_api.completion(
@ -122,14 +157,17 @@ class MetaReferenceEvalImpl(Eval):
}
)
elif ColumnName.chat_completion_input.value in x:
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
chat_completion_input_str = str(
x[ColumnName.chat_completion_input.value]
)
input_messages = eval(chat_completion_input_str)
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += input_messages
response = await self.inference_api.chat_completion(
model=candidate.model,
model_id=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
)
@ -147,23 +185,33 @@ class MetaReferenceEvalImpl(Eval):
for input_r, generated_r in zip(input_rows, generations)
]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
if job_id in self.jobs:
return JobStatus.completed
return None
async def job_cancel(self, job_id: str) -> None:
async def job_cancel(self, task_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, job_id: str) -> EvaluateResponse:
status = await self.job_status(job_id)
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(task_id, job_id)
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -86,6 +86,7 @@ class Llama:
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(config.model)
llama_model = model.core_model_id.value
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
@ -186,13 +187,20 @@ class Llama:
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
return Llama(model, tokenizer, model_args, llama_model)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(
self,
model: Transformer,
tokenizer: Tokenizer,
args: ModelArgs,
llama_model: str,
):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
self.llama_model = llama_model
@torch.inference_mode()
def generate(
@ -369,7 +377,7 @@ class Llama:
self,
request: ChatCompletionRequest,
) -> Generator:
messages = chat_completion_request_to_messages(request)
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens

View file

@ -11,8 +11,15 @@ from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
@ -23,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
ModelRegistryHelper.__init__(
self,
[
build_model_alias(
model.descriptor(),
model.core_model_id.value,
)
],
)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
@ -40,17 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
else:
self.generator = Llama.build(self.config)
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Dynamic model registration is not supported")
async def list_models(self) -> List[ModelDef]:
return [
ModelDef(
identifier=self.model.descriptor(),
llama_model=self.model.descriptor(),
)
]
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
@ -66,9 +71,12 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -79,7 +87,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
request = CompletionRequest(
model=model,
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@ -87,6 +95,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
if request.stream:
return self._stream_completion(request)
@ -185,7 +194,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -200,7 +209,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -211,6 +220,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@ -384,7 +394,35 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
url = await convert_image_media_to_url(content, download=True)
return ImageMedia(image=URL(uri=url))
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
if isinstance(request, ChatCompletionRequest):
for m in request.messages:
m.content = await _convert_content(m.content)
else:
request.content = await _convert_content(request.content)
return request

View file

@ -20,6 +20,7 @@ from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import nn, Tensor
@ -27,9 +28,7 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceQuantizedInferenceConfig,
)
from ..config import MetaReferenceQuantizedInferenceConfig
def swiglu_wrapper(

View file

@ -20,7 +20,7 @@ from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
@ -83,19 +83,11 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if self.engine:
self.engine.shutdown_background_loop()
async def register_model(self, model: ModelDef) -> None:
async def register_model(self, model: Model) -> None:
raise ValueError(
"You cannot dynamically add a model to a running vllm instance"
)
async def list_models(self) -> List[ModelDef]:
return [
ModelDef(
identifier=self.config.model,
llama_model=self.config.model,
)
]
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
@ -116,9 +108,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
return VLLMSamplingParams(**kwargs)
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -128,7 +123,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("vLLM completion")
messages = [UserMessage(content=content)]
return self.chat_completion(
model=model,
model=model_id,
messages=messages,
sampling_params=sampling_params,
stream=stream,
@ -137,7 +132,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@ -152,7 +147,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
assert self.engine is not None
request = ChatCompletionRequest(
model=model,
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -223,7 +218,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
yield chunk
async def embeddings(
self, model: str, contents: list[InterleavedTextMedia]
self, model_id: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
@json_schema_type
class FaissImplConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix()
) # Uses SQLite config specific to FAISS storage

View file

@ -4,11 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
@ -16,6 +19,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
@ -28,15 +32,59 @@ from .config import FaissImplConfig
logger = logging.getLogger(__name__)
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
class FaissIndex(EmbeddingIndex):
id_by_index: Dict[int, str]
chunk_by_index: Dict[int, str]
def __init__(self, dimension: int):
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
self.index = faiss.IndexFlatL2(dimension)
self.id_by_index = {}
self.chunk_by_index = {}
self.kvstore = kvstore
self.bank_id = bank_id
self.initialize()
async def initialize(self) -> None:
if not self.kvstore:
return
index_key = f"faiss_index:v1::{self.bank_id}"
stored_data = await self.kvstore.get(index_key)
if stored_data:
data = json.loads(stored_data)
self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
self.chunk_by_index = {
int(k): Chunk.model_validate_json(v)
for k, v in data["chunk_by_index"].items()
}
index_bytes = base64.b64decode(data["faiss_index"])
self.index = faiss.deserialize_index(index_bytes)
async def _save_index(self):
if not self.kvstore or not self.bank_id:
return
index_bytes = faiss.serialize_index(self.index)
data = {
"id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
"faiss_index": base64.b64encode(index_bytes).decode(),
}
index_key = f"faiss_index:v1::{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data))
async def delete(self):
if not self.kvstore or not self.bank_id:
return
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
@ -47,6 +95,9 @@ class FaissIndex(EmbeddingIndex):
self.index.add(np.array(embeddings).astype(np.float32))
# Save updated index
await self._save_index()
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
@ -69,27 +120,56 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
self.kvstore = None
async def initialize(self) -> None: ...
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing banks from kvstore
start_key = MEMORY_BANKS_PREFIX
end_key = f"{MEMORY_BANKS_PREFIX}\xff"
stored_banks = await self.kvstore.range(start_key, end_key)
async def shutdown(self) -> None: ...
for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex(
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
)
self.cache[bank.identifier] = index
async def shutdown(self) -> None:
# Cleanup if needed
pass
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
# Store in kvstore
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
await self.kvstore.set(
key=key,
value=memory_bank.json(),
)
# Store in cache
index = BankWithIndex(
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
bank=memory_bank,
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBankDef]:
async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()]
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
async def insert_documents(
self,
bank_id: str,

View file

@ -14,6 +14,12 @@ from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
]
class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config
@ -24,19 +30,21 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type != ShieldType.code_scanner.value:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
raise ValueError(
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
)
async def run_shield(
self,
shield_type: str,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
from codeshield.cs import CodeShield

View file

@ -7,5 +7,5 @@
from pydantic import BaseModel
class CodeShieldConfig(BaseModel):
class CodeScannerConfig(BaseModel):
pass

View file

@ -4,15 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherProviderDataValidator, TogetherSafetyConfig # noqa: F401
from .config import LlamaGuardConfig
async def get_adapter_impl(config: TogetherSafetyConfig, _deps):
from .together import TogetherSafetyImpl
async def get_provider_impl(config: LlamaGuardConfig, deps):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(
config, TogetherSafetyConfig
config, LlamaGuardConfig
), f"Unexpected config type: {type(config)}"
impl = TogetherSafetyImpl(config)
impl = LlamaGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from pydantic import BaseModel
class LlamaGuardConfig(BaseModel):
excluded_categories: List[str] = []

View file

@ -7,16 +7,21 @@
import re
from string import Template
from typing import List, Optional
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import LlamaGuardConfig
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"
_INSTANCE = None
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
@ -68,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS,
]
LLAMA_GUARD_MODEL_IDS = [
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
MODEL_TO_SAFETY_CATEGORIES_MAP = {
CoreModelId.llama_guard_3_8b.value: (
@ -107,16 +117,55 @@ PROMPT_TEMPLATE = Template(
)
class LlamaGuardShield(ShieldBase):
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
raise ValueError(
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
)
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Unknown shield {shield_id}")
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
impl = LlamaGuardShield(
model=shield.provider_resource_id,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
return await impl.run(messages)
class LlamaGuardShield:
def __init__(
self,
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
excluded_categories: Optional[List[str]] = None,
):
super().__init__(on_violation_action)
if excluded_categories is None:
excluded_categories = []
@ -174,7 +223,7 @@ class LlamaGuardShield(ShieldBase):
)
return messages
async def run(self, messages: List[Message]) -> ShieldResponse:
async def run(self, messages: List[Message]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@ -185,7 +234,7 @@ class LlamaGuardShield(ShieldBase):
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in await self.inference_api.chat_completion(
model=self.model,
model_id=self.model,
messages=[shield_input_message],
stream=True,
):
@ -195,8 +244,7 @@ class LlamaGuardShield(ShieldBase):
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response
return self.get_shield_response(content)
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
@ -250,19 +298,23 @@ class LlamaGuardShield(ShieldBase):
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
def get_shield_response(self, response: str) -> RunShieldResponse:
response = response.strip()
if response == SAFE_RESPONSE:
return ShieldResponse(is_violation=False)
return RunShieldResponse(violation=None)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(is_violation=False)
return ShieldResponse(
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
return RunShieldResponse(violation=None)
return RunShieldResponse(
violation=SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message=CANNED_RESPONSE_TEXT,
metadata={"violation_type": unsafe_code},
),
)
raise ValueError(f"Unexpected response: {response}")

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel, field_validator
class PromptGuardType(Enum):
injection = "injection"
jailbreak = "jailbreak"
class PromptGuardConfig(BaseModel):
guard_type: str = PromptGuardType.injection.value
@classmethod
@field_validator("guard_type")
def validate_guard_type(cls, v):
if v not in [t.value for t in PromptGuardType]:
raise ValueError(f"Unknown prompt guard type: {v}")
return v

View file

@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import torch
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import PromptGuardConfig, PromptGuardType
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: PromptGuardConfig, _deps) -> None:
self.config = config
async def initialize(self) -> None:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
self.shield = PromptGuardShield(model_dir, self.config)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
)
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Unknown shield {shield_id}")
return await self.shield.run(messages)
class PromptGuardShield:
def __init__(
self,
model_dir: str,
config: PromptGuardConfig,
threshold: float = 0.9,
temperature: float = 1.0,
):
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.config = config
self.temperature = temperature
self.threshold = threshold
self.device = "cuda"
# load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_text_media_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
violation = None
if self.config.guard_type == PromptGuardType.injection.value and (
score_embedded + score_malicious > self.threshold
):
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="Sorry, I cannot do this.",
metadata={
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
},
)
elif (
self.config.guard_type == PromptGuardType.jailbreak.value
and score_malicious > self.threshold
):
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return RunShieldResponse(violation=violation)

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec],
):
from .scoring import BasicScoringImpl
impl = BasicScoringImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
await impl.initialize()
return impl

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.eval import * # noqa: F401, F403
from pydantic import BaseModel
class MetaReferenceEvalConfig(BaseModel): ...
class BasicScoringConfig(BaseModel): ...

View file

@ -11,55 +11,37 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
EqualityScoringFn,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
LlmAsJudgeScoringFn,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
SubsetOfScoringFn,
)
from .config import MetaReferenceScoringConfig
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self,
config: MetaReferenceScoringConfig,
config: BasicScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
inference_api: Inference,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None:
for x in FIXED_FNS:
impl = x()
for fn in FIXED_FNS:
impl = fn()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
for x in LLM_JUDGE_FNS:
impl = x(inference_api=self.inference_api)
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
@ -68,17 +50,16 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"meta-reference"
), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
"basic"
), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
@ -97,7 +78,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@ -106,7 +87,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
@ -118,14 +100,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
score_results = await scoring_fn.score(input_rows, scoring_fn_id)
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,

View file

@ -4,20 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_accuracy,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
equality,
)
from .fn_defs.equality import equality
class EqualityScoringFn(BaseScoringFn):
@ -35,6 +29,7 @@ class EqualityScoringFn(BaseScoringFn):
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "equality",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (

View file

@ -5,12 +5,14 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
equality = ScoringFnDef(
identifier="meta-reference::equality",
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[],
params=None,
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
)

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:",
r"Answer\s*:", # Korean invisible character
r"উত্তর\s*:",
r"उत्तर\s*:",
r"উত্তরঃ",
r"উত্তর\s*:",
r"Antwort\s*:",
r"답변\s*:",
r"정답\s*:",
r"\s*:",
r"答案\s*",
r"答案\s*:",
r"\s*",
r"\s*:",
r"答复\s*",
r"答曰\s*",
r"الإجابة:",
r"الجواب:",
r"إجابة:",
r"الإجابة النهائية:",
r"الإجابة الصحيحة:",
r"الإجابة الصحيحة هي:",
r"الإجابة هي:",
r"Respuesta\s*:",
r"Risposta\s*:",
r"答え\s*:",
r"答え\s*",
r"回答\s*:",
r"回答\s*",
r"解答\s*:",
r"Jawaban\s*:",
r"Réponse\s*:",
r"Resposta\s*:",
r"Jibu\s*:",
r"Idahun\s*:",
r"Ìdáhùn\s*:",
r"Idáhùn\s*:",
r"Àmọ̀nà\s*:",
r"Àdáhùn\s*:",
r"Ànúgọ\s*:",
r"Àṣàyàn\s*:",
]
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
)
regex_parser_multiple_choice_answer = ScoringFn(
identifier="basic::regex_parser_multiple_choice_answer",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
),
)

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
subset_of = ScoringFnDef(
identifier="meta-reference::subset_of",
subset_of = ScoringFn(
identifier="basic::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
)

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
)
class RegexParserScoringFn(BaseScoringFn):
"""
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert (
fn_def.params is not None
and fn_def.params.type == ScoringFnParamsType.regex_parser.value
), f"RegexParserScoringFnParams not found for {fn_def}."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
# parse answer according to regex
parsed_answer = None
for regex in fn_def.params.parsing_regexes:
match = re.search(regex, generated_answer)
if match:
parsed_answer = match.group(1)
break
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -4,19 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_accuracy,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
subset_of,
)
from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(BaseScoringFn):
@ -34,6 +28,7 @@ class SubsetOfScoringFn(BaseScoringFn):
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "subset_of",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

View file

@ -16,9 +16,8 @@ from llama_stack.apis.datasets import * # noqa: F403
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_average,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
@ -49,7 +48,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
@ -58,13 +57,13 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError(
"Registering scoring function not allowed for braintrust provider"
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."

Some files were not shown because too many files have changed in this diff Show more