Make Llama Guard 1B the default

Ashwin Bharambe 2024-10-02 09:48:26 -07:00
parent cc5029a716
commit 4a75d922a9
5 changed files with 14 additions and 10 deletions


@@ -59,7 +59,7 @@ async def run_main(host: str, port: int, stream: bool):
     response = await client.get_model("Meta-Llama3.1-8B-Instruct")
     cprint(f"get_model response={response}", "blue")
-    response = await client.get_model("Llama-Guard-3-8B")
+    response = await client.get_model("Llama-Guard-3-1B")
     cprint(f"get_model response={response}", "red")


@@ -20,7 +20,7 @@ class MetaReferenceShieldType(Enum):
 class LlamaGuardShieldConfig(BaseModel):
-    model: str = "Llama-Guard-3-8B"
+    model: str = "Llama-Guard-3-1B"
     excluded_categories: List[str] = []
     disable_input_check: bool = False
     disable_output_check: bool = False
@@ -33,7 +33,11 @@ class LlamaGuardShieldConfig(BaseModel):
             for m in safety_models()
             if (
                 m.core_model_id
-                in {CoreModelId.llama_guard_3_8b, CoreModelId.llama_guard_3_11b_vision}
+                in {
+                    CoreModelId.llama_guard_3_8b,
+                    CoreModelId.llama_guard_3_1b,
+                    CoreModelId.llama_guard_3_11b_vision,
+                }
             )
         ]
         if model not in permitted_models:
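In practical terms, a default-constructed config now selects the 1B guard model, and the 1B model id is added to the permitted set checked above. A rough sketch of the intended behavior, assuming the fragment above is a Pydantic field validator whose permitted_models entries match these model names (the validator's signature and error handling are not visible in this hunk):

    # Import path for LlamaGuardShieldConfig is not shown in this diff.

    # Default-constructed config picks the new 1B default.
    config = LlamaGuardShieldConfig()
    assert config.model == "Llama-Guard-3-1B"

    # The permitted set now contains llama_guard_3_8b, llama_guard_3_1b, and
    # llama_guard_3_11b_vision, so explicitly requesting the 1B model should no
    # longer be rejected by the `if model not in permitted_models:` branch.
    config = LlamaGuardShieldConfig(model="Llama-Guard-3-1B")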