fix prompt guard (#177)

Several other fixes to configure. Add support for 1b/3b models in ollama.
This commit is contained in:
Ashwin Bharambe 2024-10-03 11:07:53 -07:00 committed by GitHub
parent b9b1e8b08b
commit 210b71b0ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 50 additions and 45 deletions

View file

@ -23,7 +23,7 @@ if [ "$#" -lt 3 ]; then
exit 1
fi
special_pip_deps="$3"
special_pip_deps="$4"
set -euo pipefail

View file

@ -6,8 +6,15 @@
from typing import Any
from pydantic import BaseModel
from llama_models.sku_list import (
llama3_1_family,
llama3_2_family,
llama3_family,
resolve_model,
safety_models,
)
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
@ -27,6 +34,11 @@ from llama_stack.providers.impls.meta_reference.safety.config import (
)
ALLOWED_MODELS = (
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
)
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
@ -104,7 +116,13 @@ def configure_api_providers(
else:
routing_key = prompt(
"> Please enter the supported model your provider has for inference: ",
default="Meta-Llama3.1-8B-Instruct",
default="Llama3.1-8B-Instruct",
validator=Validator.from_callable(
lambda x: resolve_model(x) is not None,
error_message="Model must be: {}".format(
[x.descriptor() for x in ALLOWED_MODELS]
),
),
)
routing_entries.append(
RoutableProviderConfig(

View file

@ -117,10 +117,10 @@ Provider configurations for each of the APIs provided by this package.
description="""
E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Meta-Llama3.1-8B-Instruct
- routing_key: Llama3.1-8B-Instruct
provider_type: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096

View file

@ -36,7 +36,7 @@ routing_table:
config:
host: localhost
port: 6000
routing_key: Meta-Llama3.1-8B-Instruct
routing_key: Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
config: