fix prompt guard (#177)
Several other fixes to configure. Add support for 1B/3B models in ollama.
This commit is contained in:
parent b9b1e8b08b
commit 210b71b0ba

11 changed files with 50 additions and 45 deletions
@@ -6,8 +6,15 @@
 from typing import Any
 
-from pydantic import BaseModel
+from llama_models.sku_list import (
+    llama3_1_family,
+    llama3_2_family,
+    llama3_family,
+    resolve_model,
+    safety_models,
+)
+
+from pydantic import BaseModel
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from prompt_toolkit import prompt
 from prompt_toolkit.validation import Validator
@@ -27,6 +34,11 @@ from llama_stack.providers.impls.meta_reference.safety.config import (
 )
 
 
+ALLOWED_MODELS = (
+    llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
+)
+
+
 def make_routing_entry_type(config_class: Any):
     class BaseModelWithConfig(BaseModel):
         routing_key: str
@@ -104,7 +116,13 @@ def configure_api_providers(
         else:
             routing_key = prompt(
                 "> Please enter the supported model your provider has for inference: ",
-                default="Meta-Llama3.1-8B-Instruct",
+                default="Llama3.1-8B-Instruct",
+                validator=Validator.from_callable(
+                    lambda x: resolve_model(x) is not None,
+                    error_message="Model must be: {}".format(
+                        [x.descriptor() for x in ALLOWED_MODELS]
+                    ),
+                ),
             )
         routing_entries.append(
             RoutableProviderConfig(

[diff truncated here; the remaining changed files were not loaded in this view]
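For context on the pattern: the new check is prompt_toolkit's Validator.from_callable wrapping llama_models' resolve_model, and the default was renamed alongside it, presumably because "Meta-Llama3.1-8B-Instruct" no longer resolves to a known SKU. Below is a minimal sketch of how that validation behaves, with a hypothetical KNOWN_MODELS set standing in for resolve_model; the set contents are illustrative, not taken from this commit.

# Sketch of the Validator.from_callable pattern used in the diff above.
# KNOWN_MODELS is a stand-in for llama_models.sku_list.resolve_model.
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator

KNOWN_MODELS = {"Llama3.1-8B-Instruct", "Llama3.2-1B-Instruct", "Llama3.2-3B-Instruct"}

validator = Validator.from_callable(
    lambda x: x in KNOWN_MODELS,
    error_message="Model must be one of: {}".format(sorted(KNOWN_MODELS)),
)

# The old default would now fail validation:
try:
    validator.validate(Document("Meta-Llama3.1-8B-Instruct"))
except ValidationError as e:
    print(e.message)  # prints the "Model must be one of: ..." message

# The new default passes (validate() returns without raising):
validator.validate(Document("Llama3.1-8B-Instruct"))

When used as the validator= argument to prompt(), the same check runs interactively, so the user cannot submit a model descriptor the SKU list does not recognize.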