Merge branch 'main' into models_api_2

This commit is contained in:
Xi Yan 2024-09-18 22:36:48 -07:00 committed by GitHub
commit df33e6fbec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 9689 additions and 113 deletions

View file

@ -7,15 +7,14 @@
from typing import List
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from termcolor import cprint
from llama_stack.apis.safety import (
OnViolationAction,
RunShieldRequest,
Safety,
ShieldDefinition,
ShieldResponse,
)
from termcolor import cprint
class SafetyException(Exception): # noqa: N818
@ -45,10 +44,8 @@ class ShieldRunnerMixin:
messages[0] = UserMessage(content=messages[0].content)
res = await self.safety_api.run_shields(
RunShieldRequest(
messages=messages,
shields=shields,
)
messages=messages,
shields=shields,
)
results = res.responses

View file

@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models, resolve_model
from llama_stack.apis.inference import QuantizationConfig
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.inference import QuantizationConfig
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
@ -24,7 +24,7 @@ class MetaReferenceImplConfig(BaseModel):
)
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int
max_seq_len: int = 4096
max_batch_size: int = 1
@field_validator("model")