mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 04:34:57 +00:00
Merge branch 'main' into models_api_2
This commit is contained in:
commit
df33e6fbec
30 changed files with 9689 additions and 113 deletions
|
@ -7,15 +7,14 @@
|
|||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.safety import (
|
||||
OnViolationAction,
|
||||
RunShieldRequest,
|
||||
Safety,
|
||||
ShieldDefinition,
|
||||
ShieldResponse,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
|
@ -45,10 +44,8 @@ class ShieldRunnerMixin:
|
|||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
res = await self.safety_api.run_shields(
|
||||
RunShieldRequest(
|
||||
messages=messages,
|
||||
shields=shields,
|
||||
)
|
||||
messages=messages,
|
||||
shields=shields,
|
||||
)
|
||||
|
||||
results = res.responses
|
||||
|
|
|
@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
|
|||
from llama_models.schema_utils import json_schema_type
|
||||
from llama_models.sku_list import all_registered_models, resolve_model
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
|
@ -24,7 +24,7 @@ class MetaReferenceImplConfig(BaseModel):
|
|||
)
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int
|
||||
max_seq_len: int = 4096
|
||||
max_batch_size: int = 1
|
||||
|
||||
@field_validator("model")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue