fix prompt guard (#177)
Several other fixes to `llama stack configure`. Adds support for the Llama 3.2 1B/3B instruct models in the Ollama adapter.
This commit is contained in:
parent b9b1e8b08b
commit 210b71b0ba
11 changed files with 50 additions and 45 deletions
@@ -117,9 +117,9 @@ llama download --source meta --model-id Llama-Guard-3-1B --meta-url META_URL
 Essentially, the same commands above work, just replace `--source meta` with `--source huggingface`.
 
 ```bash
-llama download --source huggingface --model-id Meta-Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
+llama download --source huggingface --model-id Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
 
-llama download --source huggingface --model-id Meta-Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
+llama download --source huggingface --model-id Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
 
 llama download --source huggingface --model-id Llama-Guard-3-1B --ignore-patterns *original*
 llama download --source huggingface --model-id Prompt-Guard-86M --ignore-patterns *original*
@@ -230,7 +230,7 @@ You will be shown a Markdown formatted description of the model interface and ho
 - Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution.
 
 ### Step 3.1 Build
-In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
+In the following steps, imagine we'll be working with a `Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
 - `name`: the name for our distribution (e.g. `8b-instruct`)
 - `image_type`: our build image type (`conda | docker`)
 - `distribution_spec`: our distribution specs for specifying API providers
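For orientation, here is a minimal sketch of the three build inputs the documentation hunk above lists. The field names come from the docs; the `BuildInputs` class and the shape of `distribution_spec` are illustrative assumptions, not the actual llama-stack build schema.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class BuildInputs:
    """Hypothetical container mirroring the three inputs listed in the docs above;
    not the real llama_stack build config class."""

    name: str                    # e.g. "8b-instruct"
    image_type: str              # "conda" or "docker"
    distribution_spec: Dict[str, str] = field(default_factory=dict)  # API -> provider (assumed shape)


inputs = BuildInputs(
    name="8b-instruct",
    image_type="conda",
    distribution_spec={"inference": "meta-reference", "safety": "meta-reference"},
)
print(inputs)
```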
@@ -365,7 +365,7 @@ llama stack configure [ <name> | <docker-image-name> | <path/to/name.build.yaml>
 $ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml
 
 Configuring API: inference (meta-reference)
-Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required):
+Enter value for model (existing: Llama3.1-8B-Instruct) (required):
 Enter value for quantization (optional):
 Enter value for torch_seed (optional):
 Enter value for max_seq_len (existing: 4096) (required):
@@ -397,7 +397,7 @@ YAML configuration has been written to ~/.llama/builds/conda/8b-instruct-run.yaml
 After this step is successful, you should be able to find a run configuration spec in `~/.llama/builds/conda/8b-instruct-run.yaml` with the following contents. You may edit this file to change the settings.
 
 As you can see, we did basic configuration above and configured:
-- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`)
+- inference to run on model `Llama3.1-8B-Instruct` (obtained from `llama model list`)
 - Llama Guard safety shield with model `Llama-Guard-3-1B`
 - Prompt Guard safety shield with model `Prompt-Guard-86M`
 
@@ -56,7 +56,7 @@ async def run_main(host: str, port: int, stream: bool):
     response = await client.list_models()
     cprint(f"list_models response={response}", "green")
 
-    response = await client.get_model("Meta-Llama3.1-8B-Instruct")
+    response = await client.get_model("Llama3.1-8B-Instruct")
     cprint(f"get_model response={response}", "blue")
 
     response = await client.get_model("Llama-Guard-3-1B")
@@ -23,7 +23,7 @@ if [ "$#" -lt 3 ]; then
   exit 1
 fi
 
-special_pip_deps="$3"
+special_pip_deps="$4"
 
 set -euo pipefail
 
@@ -6,8 +6,15 @@
 
 from typing import Any
 
-from pydantic import BaseModel
+from llama_models.sku_list import (
+    llama3_1_family,
+    llama3_2_family,
+    llama3_family,
+    resolve_model,
+    safety_models,
+)
 
+from pydantic import BaseModel
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from prompt_toolkit import prompt
 from prompt_toolkit.validation import Validator
@@ -27,6 +34,11 @@ from llama_stack.providers.impls.meta_reference.safety.config import (
 )
 
 
+ALLOWED_MODELS = (
+    llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
+)
+
+
 def make_routing_entry_type(config_class: Any):
     class BaseModelWithConfig(BaseModel):
         routing_key: str
@@ -104,7 +116,13 @@ def configure_api_providers(
         else:
             routing_key = prompt(
                 "> Please enter the supported model your provider has for inference: ",
-                default="Meta-Llama3.1-8B-Instruct",
+                default="Llama3.1-8B-Instruct",
+                validator=Validator.from_callable(
+                    lambda x: resolve_model(x) is not None,
+                    error_message="Model must be: {}".format(
+                        [x.descriptor() for x in ALLOWED_MODELS]
+                    ),
+                ),
             )
             routing_entries.append(
                 RoutableProviderConfig(
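To illustrate the new guard on the inference routing key, here is a minimal standalone sketch of the same `Validator.from_callable` pattern. It uses a hard-coded allow-list in place of `llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()` so it runs without `llama_models` installed; the real code resolves descriptors via `resolve_model`.

```python
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator

# Stand-in for ALLOWED_MODELS; the real code builds this list from the llama_models SKU list.
ALLOWED_DESCRIPTORS = {"Llama3.1-8B-Instruct", "Llama3.2-3B-Instruct", "Llama-Guard-3-1B"}

validator = Validator.from_callable(
    lambda x: x in ALLOWED_DESCRIPTORS,
    error_message="Model must be one of: {}".format(sorted(ALLOWED_DESCRIPTORS)),
)

# Exercise the validator directly, outside of an interactive prompt() call:
validator.validate(Document("Llama3.1-8B-Instruct"))  # accepted: no exception raised
try:
    validator.validate(Document("Meta-Llama3.1-8B-Instruct"))  # old-style name is rejected
except ValidationError as exc:
    print("rejected:", exc.message)
```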
@@ -117,10 +117,10 @@ Provider configurations for each of the APIs provided by this package.
         description="""
 
 E.g. The following is a ProviderRoutingEntry for models:
-- routing_key: Meta-Llama3.1-8B-Instruct
+- routing_key: Llama3.1-8B-Instruct
   provider_type: meta-reference
   config:
-    model: Meta-Llama3.1-8B-Instruct
+    model: Llama3.1-8B-Instruct
     quantization: null
     torch_seed: null
     max_seq_len: 4096
@@ -36,7 +36,7 @@ routing_table:
     config:
       host: localhost
       port: 6000
-    routing_key: Meta-Llama3.1-8B-Instruct
+    routing_key: Llama3.1-8B-Instruct
   safety:
   - provider_type: meta-reference
     config:
@@ -7,6 +7,10 @@
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 
 
+class OllamaImplConfig(RemoteProviderConfig):
+    port: int = 11434
+
+
 async def get_adapter_impl(config: RemoteProviderConfig, _deps):
     from .ollama import OllamaInferenceAdapter
 
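A quick sketch of what the new adapter config enables: the Ollama adapter now carries its own config class whose port defaults to Ollama's standard 11434 rather than a generic remote-provider port. The `RemoteProviderConfig` below is a trimmed stand-in assumed to carry `host`/`port`, and the URL formatting is illustrative; consult the real class in `llama_stack.distribution.datatypes` for its actual fields.

```python
from pydantic import BaseModel


class RemoteProviderConfig(BaseModel):
    """Minimal stand-in for llama_stack's RemoteProviderConfig (assumed host/port fields)."""

    host: str = "localhost"
    port: int = 5000  # illustrative generic default, not the real one


class OllamaImplConfig(RemoteProviderConfig):
    port: int = 11434  # Ollama's default HTTP port


cfg = OllamaImplConfig()
print(f"http://{cfg.host}:{cfg.port}")  # -> http://localhost:11434
```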
@@ -23,9 +23,10 @@ from llama_stack.providers.utils.inference.routable import RoutableProviderForMo
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
-    # "Llama3.1-8B-Instruct": "llama3.1",
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
+    "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
+    "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
 }
 
 
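A minimal sketch of how a routing key can be translated into an Ollama model tag using this table. The helper name `ollama_model_for` and its error handling are illustrative, not the adapter's actual code.

```python
OLLAMA_SUPPORTED_SKUS = {
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
    "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
    "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
}


def ollama_model_for(routing_key: str) -> str:
    """Map a Llama Stack model descriptor to an Ollama tag (illustrative helper)."""
    try:
        return OLLAMA_SUPPORTED_SKUS[routing_key]
    except KeyError:
        raise ValueError(
            f"{routing_key} is not supported; choose one of {sorted(OLLAMA_SUPPORTED_SKUS)}"
        ) from None


print(ollama_model_for("Llama3.2-3B-Instruct"))  # -> llama3.2:3b-instruct-fp16
```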
@@ -47,10 +47,6 @@ class LlamaGuardShieldConfig(BaseModel):
         return model
 
 
-class PromptGuardShieldConfig(BaseModel):
-    model: str = "Prompt-Guard-86M"
-
-
 class SafetyConfig(BaseModel):
     llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
-    prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
+    enable_prompt_guard: Optional[bool] = False
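The user-facing effect of this hunk: Prompt Guard is now switched on with a single boolean rather than a nested shield config block carrying a model name. A minimal sketch with a trimmed stand-in for `LlamaGuardShieldConfig` (the real class has more fields and validators):

```python
from typing import Optional

from pydantic import BaseModel


class LlamaGuardShieldConfig(BaseModel):
    """Trimmed stand-in for the real config class."""

    model: str = "Llama-Guard-3-1B"


class SafetyConfig(BaseModel):
    llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
    enable_prompt_guard: Optional[bool] = False


# Before this commit a PromptGuardShieldConfig with a model name was required;
# now the model is fixed to Prompt-Guard-86M and you only flip a flag:
cfg = SafetyConfig(llama_guard_shield=LlamaGuardShieldConfig(), enable_prompt_guard=True)
print(cfg)
```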
@@ -6,8 +6,6 @@
 
 from typing import Any, Dict, List
 
-from llama_models.sku_list import resolve_model
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
@@ -20,21 +18,9 @@ from llama_stack.providers.impls.meta_reference.safety.shields.base import (
 
 from .config import MetaReferenceShieldType, SafetyConfig
 
-from .shields import (
-    CodeScannerShield,
-    InjectionShield,
-    JailbreakShield,
-    LlamaGuardShield,
-    PromptGuardShield,
-    ShieldBase,
-)
+from .shields import CodeScannerShield, LlamaGuardShield, ShieldBase
 
-
-def resolve_and_get_path(model_name: str) -> str:
-    model = resolve_model(model_name)
-    assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model.descriptor())
-    return model_dir
+PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
 
 class MetaReferenceSafetyImpl(Safety, RoutableProvider):
@@ -43,9 +29,10 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
         self.inference_api = deps[Api.inference]
 
     async def initialize(self) -> None:
-        shield_cfg = self.config.prompt_guard_shield
-        if shield_cfg is not None:
-            model_dir = resolve_and_get_path(shield_cfg.model)
+        if self.config.enable_prompt_guard:
+            from .shields import PromptGuardShield
+
+            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
             _ = PromptGuardShield.instance(model_dir)
 
     async def shutdown(self) -> None:
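The pattern here is deferred import plus conditional initialization: the Prompt Guard shield is only imported and instantiated when `enable_prompt_guard` is set. A schematic, self-contained sketch of the same idea using placeholder classes (names and the model path are illustrative, not the provider's real ones):

```python
import asyncio


class DemoSafetyConfig:
    """Placeholder for SafetyConfig; only the flag relevant here."""

    enable_prompt_guard = True  # flip to False and initialize() does no shield work


class DemoPromptGuardShield:
    """Placeholder standing in for the real PromptGuardShield singleton."""

    @classmethod
    def instance(cls, model_dir: str) -> "DemoPromptGuardShield":
        print(f"would load Prompt Guard weights from {model_dir}")
        return cls()


async def initialize(config: DemoSafetyConfig) -> None:
    if config.enable_prompt_guard:
        # The real code defers `from .shields import PromptGuardShield` to this point,
        # so the heavier shield module is only imported when the flag is on.
        DemoPromptGuardShield.instance("/path/to/Prompt-Guard-86M")  # path is illustrative


asyncio.run(initialize(DemoSafetyConfig()))
```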
|
@ -108,16 +95,14 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
||||||
disable_output_check=cfg.disable_output_check,
|
disable_output_check=cfg.disable_output_check,
|
||||||
)
|
)
|
||||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||||
assert (
|
from .shields import JailbreakShield
|
||||||
cfg.prompt_guard_shield is not None
|
|
||||||
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
|
||||||
return JailbreakShield.instance(model_dir)
|
return JailbreakShield.instance(model_dir)
|
||||||
elif typ == MetaReferenceShieldType.injection_shield:
|
elif typ == MetaReferenceShieldType.injection_shield:
|
||||||
assert (
|
from .shields import InjectionShield
|
||||||
cfg.prompt_guard_shield is not None
|
|
||||||
), "Cannot use PromptGuardShield since not present in config"
|
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
|
||||||
return InjectionShield.instance(model_dir)
|
return InjectionShield.instance(model_dir)
|
||||||
elif typ == MetaReferenceShieldType.code_scanner_guard:
|
elif typ == MetaReferenceShieldType.code_scanner_guard:
|
||||||
return CodeScannerShield.instance()
|
return CodeScannerShield.instance()
|
||||||
|
|
|
@@ -41,6 +41,7 @@ def available_providers() -> List[ProviderSpec]:
         adapter=AdapterSpec(
             adapter_type="ollama",
             pip_packages=["ollama"],
+            config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
             module="llama_stack.providers.adapters.inference.ollama",
         ),
     ),