Use inference APIs for executing Llama Guard (#121)

We should use the Inference APIs to execute Llama Guard instead of depending directly on HuggingFace modeling code. The actual inference work is then handled by the Inference provider.
Ashwin Bharambe 2024-09-28 15:40:06 -07:00 committed by GitHub
parent 6236634d84
commit 0a3999a9a4
9 changed files with 167 additions and 204 deletions
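
In practice this means a safety shield no longer instantiates a HuggingFace model itself; it prompts Llama Guard through whatever Inference provider is configured. Below is a minimal sketch of that pattern; the names `InferenceAPI`, `Message`, and `run_llama_guard` are illustrative stand-ins, not the exact llama-stack symbols touched by this commit.

# Sketch only: interface and helper names are illustrative,
# not the actual llama-stack APIs changed here.
from typing import List, Protocol


class Message:
    """Minimal stand-in for a chat message type."""

    def __init__(self, role: str, content: str) -> None:
        self.role = role
        self.content = content


class InferenceAPI(Protocol):
    """Anything that can run chat completions: local GPU, remote endpoint, etc."""

    async def chat_completion(self, model: str, messages: List[Message]) -> str: ...


async def run_llama_guard(inference: InferenceAPI, messages: List[Message]) -> bool:
    """Classify a conversation by prompting Llama Guard through the
    inference API rather than loading HuggingFace weights directly."""
    response = await inference.chat_completion(
        model="Llama-Guard-3-8B",
        messages=messages,
    )
    # Llama Guard replies with "safe", or "unsafe" followed by the
    # violated category codes.
    return response.strip().lower().startswith("safe")

The safety code then works against any backend that implements the inference interface, which is what lets the HuggingFace-specific modeling code be deleted.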


@@ -7,12 +7,13 @@
 from typing import Optional

 from llama_models.datatypes import *  # noqa: F403
-from llama_models.sku_list import all_registered_models, resolve_model
+from llama_models.sku_list import resolve_model
 from llama_stack.apis.inference import *  # noqa: F401, F403
 from pydantic import BaseModel, Field, field_validator

+from llama_stack.providers.utils.inference import supported_inference_models
+

 class MetaReferenceImplConfig(BaseModel):
     model: str = Field(
@@ -27,12 +28,7 @@ class MetaReferenceImplConfig(BaseModel):
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
-        permitted_models = [
-            m.descriptor()
-            for m in all_registered_models()
-            if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
-            or m.core_model_id == CoreModelId.llama_guard_3_8b
-        ]
+        permitted_models = supported_inference_models()
        if model not in permitted_models:
             model_list = "\n\t".join(permitted_models)
             raise ValueError(
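
The new supported_inference_models() helper centralizes the allow-list that the validator previously built inline. Judging from the removed list comprehension, it plausibly looks something like the sketch below; this is a reconstruction, not necessarily the exact contents of llama_stack.providers.utils.inference.

# Reconstructed from the removed list comprehension; the real helper
# in llama_stack.providers.utils.inference may differ in details.
from typing import List

from llama_models.datatypes import CoreModelId, ModelFamily
from llama_models.sku_list import all_registered_models


def supported_inference_models() -> List[str]:
    """Descriptors of every model the inference providers can serve:
    the Llama 3.1 and 3.2 families, plus the Llama Guard 3 8B safety model."""
    return [
        m.descriptor()
        for m in all_registered_models()
        if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
        or m.core_model_id == CoreModelId.llama_guard_3_8b
    ]

With the list defined in one place, the inference config validator and the safety provider can agree on which models are servable instead of each maintaining its own copy.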