mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
formatting
This commit is contained in:
parent
94dfa293a6
commit
b6ccaf1778
33 changed files with 110 additions and 97 deletions
|
@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
|
|||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
step_type: Literal[StepType.memory_retrieval.value] = (
|
||||
step_type: Literal[
|
||||
StepType.memory_retrieval.value
|
||||
)
|
||||
] = StepType.memory_retrieval.value
|
||||
memory_bank_ids: List[str]
|
||||
documents: List[MemoryBankDocument]
|
||||
scores: List[float]
|
||||
|
@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
|
||||
event_type: Literal[
|
||||
AgenticSystemTurnResponseEventType.step_start.value
|
||||
)
|
||||
] = AgenticSystemTurnResponseEventType.step_start.value
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
|
||||
event_type: Literal[
|
||||
AgenticSystemTurnResponseEventType.step_complete.value
|
||||
)
|
||||
] = AgenticSystemTurnResponseEventType.step_complete.value
|
||||
step_type: StepType
|
||||
step_details: Step
|
||||
|
||||
|
@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
|||
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
||||
event_type: Literal[
|
||||
AgenticSystemTurnResponseEventType.step_progress.value
|
||||
)
|
||||
] = AgenticSystemTurnResponseEventType.step_progress.value
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
|
@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
|
||||
event_type: Literal[
|
||||
AgenticSystemTurnResponseEventType.turn_start.value
|
||||
)
|
||||
] = AgenticSystemTurnResponseEventType.turn_start.value
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
|
||||
event_type: Literal[
|
||||
AgenticSystemTurnResponseEventType.turn_complete.value
|
||||
)
|
||||
] = AgenticSystemTurnResponseEventType.turn_complete.value
|
||||
turn: Turn
|
||||
|
||||
|
||||
|
|
|
@ -63,36 +63,40 @@ class AgenticSystemStepResponse(BaseModel):
|
|||
|
||||
|
||||
class AgenticSystem(Protocol):
|
||||
|
||||
@webmethod(route="/agentic_system/create")
|
||||
async def create_agentic_system(
|
||||
self,
|
||||
request: AgenticSystemCreateRequest,
|
||||
) -> AgenticSystemCreateResponse: ...
|
||||
) -> AgenticSystemCreateResponse:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/turn/create")
|
||||
async def create_agentic_system_turn(
|
||||
self,
|
||||
request: AgenticSystemTurnCreateRequest,
|
||||
) -> AgenticSystemTurnResponseStreamChunk: ...
|
||||
) -> AgenticSystemTurnResponseStreamChunk:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/turn/get")
|
||||
async def get_agentic_system_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn: ...
|
||||
) -> Turn:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/step/get")
|
||||
async def get_agentic_system_step(
|
||||
self, agent_id: str, turn_id: str, step_id: str
|
||||
) -> AgenticSystemStepResponse: ...
|
||||
) -> AgenticSystemStepResponse:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/session/create")
|
||||
async def create_agentic_system_session(
|
||||
self,
|
||||
request: AgenticSystemSessionCreateRequest,
|
||||
) -> AgenticSystemSessionCreateResponse: ...
|
||||
) -> AgenticSystemSessionCreateResponse:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/memory_bank/attach")
|
||||
async def attach_memory_bank_to_agentic_system(
|
||||
|
@ -100,7 +104,8 @@ class AgenticSystem(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
memory_bank_ids: List[str],
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/memory_bank/detach")
|
||||
async def detach_memory_bank_from_agentic_system(
|
||||
|
@ -108,7 +113,8 @@ class AgenticSystem(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
memory_bank_ids: List[str],
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/session/get")
|
||||
async def get_agentic_system_session(
|
||||
|
@ -116,15 +122,18 @@ class AgenticSystem(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session: ...
|
||||
) -> Session:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/session/delete")
|
||||
async def delete_agentic_system_session(
|
||||
self, agent_id: str, session_id: str
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/agentic_system/delete")
|
||||
async def delete_agentic_system(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
|
|
@ -246,7 +246,6 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
await self.run_shields(messages, shields)
|
||||
|
||||
except SafetyException as e:
|
||||
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
event=AgenticSystemTurnResponseEvent(
|
||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
||||
|
|
|
@ -23,7 +23,6 @@ class SafetyException(Exception): # noqa: N818
|
|||
|
||||
|
||||
class ShieldRunnerMixin:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
safety_api: Safety,
|
||||
|
|
|
@ -11,7 +11,6 @@ from llama_toolchain.inference.api import Message
|
|||
|
||||
|
||||
class BaseTool(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -66,7 +66,6 @@ class SingleMessageBuiltinTool(BaseTool):
|
|||
|
||||
|
||||
class PhotogenTool(SingleMessageBuiltinTool):
|
||||
|
||||
def __init__(self, dump_dir: str) -> None:
|
||||
self.dump_dir = dump_dir
|
||||
|
||||
|
@ -87,7 +86,6 @@ class PhotogenTool(SingleMessageBuiltinTool):
|
|||
|
||||
|
||||
class BraveSearchTool(SingleMessageBuiltinTool):
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
|
||||
|
@ -204,7 +202,6 @@ class BraveSearchTool(SingleMessageBuiltinTool):
|
|||
|
||||
|
||||
class WolframAlphaTool(SingleMessageBuiltinTool):
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.url = "https://api.wolframalpha.com/v2/query"
|
||||
|
@ -287,7 +284,6 @@ class WolframAlphaTool(SingleMessageBuiltinTool):
|
|||
|
||||
|
||||
class CodeInterpreterTool(BaseTool):
|
||||
|
||||
def __init__(self) -> None:
|
||||
ctx = CodeExecutionContext(
|
||||
matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",
|
||||
|
|
|
@ -27,7 +27,6 @@ from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
|
|||
|
||||
|
||||
class AgenticSystemClientWrapper:
|
||||
|
||||
def __init__(self, api, system_id, custom_tools):
|
||||
self.api = api
|
||||
self.system_id = system_id
|
||||
|
|
|
@ -10,7 +10,6 @@ from llama_toolchain.cli.subcommand import Subcommand
|
|||
|
||||
|
||||
class DistributionCreate(Subcommand):
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
|
|
|
@ -16,7 +16,6 @@ from .start import DistributionStart
|
|||
|
||||
|
||||
class DistributionParser(Subcommand):
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
|
|
|
@ -11,7 +11,6 @@ from llama_toolchain.cli.subcommand import Subcommand
|
|||
|
||||
|
||||
class DistributionList(Subcommand):
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
|
|
|
@ -14,7 +14,6 @@ from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
|
|||
|
||||
|
||||
class DistributionStart(Subcommand):
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
|
|
|
@ -27,7 +27,11 @@ class ModelList(Subcommand):
|
|||
self.parser.set_defaults(func=self._run_model_list_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument('--show-all', action='store_true', help='Show all models (not just defaults)')
|
||||
self.parser.add_argument(
|
||||
"--show-all",
|
||||
action="store_true",
|
||||
help="Show all models (not just defaults)",
|
||||
)
|
||||
|
||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
headers = [
|
||||
|
|
|
@ -26,16 +26,19 @@ class Datasets(Protocol):
|
|||
def create_dataset(
|
||||
self,
|
||||
request: CreateDatasetRequest,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/get")
|
||||
def get_dataset(
|
||||
self,
|
||||
dataset_uuid: str,
|
||||
) -> TrainEvalDataset: ...
|
||||
) -> TrainEvalDataset:
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/delete")
|
||||
def delete_dataset(
|
||||
self,
|
||||
dataset_uuid: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
|
|
@ -217,7 +217,6 @@ def create_dynamic_typed_route(func: Any):
|
|||
|
||||
|
||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||
|
||||
by_id = {x.api: x for x in providers}
|
||||
|
||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
|
|
|
@ -26,10 +26,8 @@ class SummarizationMetric(Enum):
|
|||
|
||||
|
||||
class EvaluationJob(BaseModel):
|
||||
|
||||
job_uuid: str
|
||||
|
||||
|
||||
class EvaluationJobLogStream(BaseModel):
|
||||
|
||||
job_uuid: str
|
||||
|
|
|
@ -48,7 +48,6 @@ class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
|
|||
|
||||
|
||||
class EvaluationJobStatusResponse(BaseModel):
|
||||
|
||||
job_uuid: str
|
||||
|
||||
|
||||
|
@ -64,36 +63,42 @@ class Evaluations(Protocol):
|
|||
def post_evaluate_text_generation(
|
||||
self,
|
||||
request: EvaluateTextGenerationRequest,
|
||||
) -> EvaluationJob: ...
|
||||
) -> EvaluationJob:
|
||||
...
|
||||
|
||||
@webmethod(route="/evaluate/question_answering/")
|
||||
def post_evaluate_question_answering(
|
||||
self,
|
||||
request: EvaluateQuestionAnsweringRequest,
|
||||
) -> EvaluationJob: ...
|
||||
) -> EvaluationJob:
|
||||
...
|
||||
|
||||
@webmethod(route="/evaluate/summarization/")
|
||||
def post_evaluate_summarization(
|
||||
self,
|
||||
request: EvaluateSummarizationRequest,
|
||||
) -> EvaluationJob: ...
|
||||
) -> EvaluationJob:
|
||||
...
|
||||
|
||||
@webmethod(route="/evaluate/jobs")
|
||||
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
|
||||
def get_evaluation_jobs(self) -> List[EvaluationJob]:
|
||||
...
|
||||
|
||||
@webmethod(route="/evaluate/job/status")
|
||||
def get_evaluation_job_status(
|
||||
self, job_uuid: str
|
||||
) -> EvaluationJobStatusResponse: ...
|
||||
def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
|
||||
...
|
||||
|
||||
# sends SSE stream of logs
|
||||
@webmethod(route="/evaluate/job/logs")
|
||||
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
|
||||
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
|
||||
...
|
||||
|
||||
@webmethod(route="/evaluate/job/cancel")
|
||||
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
|
||||
def cancel_evaluation_job(self, job_uuid: str) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/evaluate/job/artifacts")
|
||||
def get_evaluation_job_artifacts(
|
||||
self, job_uuid: str
|
||||
) -> EvaluationJobArtifactsResponse: ...
|
||||
) -> EvaluationJobArtifactsResponse:
|
||||
...
|
||||
|
|
|
@ -97,27 +97,30 @@ class BatchChatCompletionResponse(BaseModel):
|
|||
|
||||
|
||||
class Inference(Protocol):
|
||||
|
||||
@webmethod(route="/inference/completion")
|
||||
async def completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/chat_completion")
|
||||
async def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
|
||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/batch_completion")
|
||||
async def batch_completion(
|
||||
self,
|
||||
request: BatchCompletionRequest,
|
||||
) -> BatchCompletionResponse: ...
|
||||
) -> BatchCompletionResponse:
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/batch_chat_completion")
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
request: BatchChatCompletionRequest,
|
||||
) -> BatchChatCompletionResponse: ...
|
||||
) -> BatchChatCompletionResponse:
|
||||
...
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.inference.api import (
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -45,7 +45,6 @@ SEMAPHORE = asyncio.Semaphore(1)
|
|||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference):
|
||||
|
||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
|
|
|
@ -54,7 +54,6 @@ async def get_provider_impl(
|
|||
|
||||
|
||||
class OllamaInference(Inference):
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -66,7 +65,9 @@ class OllamaInference(Inference):
|
|||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError:
|
||||
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
|
||||
raise RuntimeError(
|
||||
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
|
@ -18,44 +18,52 @@ class MemoryBanks(Protocol):
|
|||
bank_id: str,
|
||||
bank_name: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/memory_banks/list")
|
||||
def get_memory_banks(self) -> List[MemoryBank]: ...
|
||||
def get_memory_banks(self) -> List[MemoryBank]:
|
||||
...
|
||||
|
||||
@webmethod(route="/memory_banks/get")
|
||||
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
|
||||
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
|
||||
...
|
||||
|
||||
@webmethod(route="/memory_banks/drop")
|
||||
def delete_memory_bank(
|
||||
self,
|
||||
bank_id: str,
|
||||
) -> str: ...
|
||||
) -> str:
|
||||
...
|
||||
|
||||
@webmethod(route="/memory_bank/insert")
|
||||
def post_insert_memory_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/memory_bank/update")
|
||||
def post_update_memory_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/memory_bank/get")
|
||||
def get_memory_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_uuids: List[str],
|
||||
) -> List[MemoryBankDocument]: ...
|
||||
) -> List[MemoryBankDocument]:
|
||||
...
|
||||
|
||||
@webmethod(route="/memory_bank/delete")
|
||||
def delete_memory_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_uuids: List[str],
|
||||
) -> List[str]: ...
|
||||
) -> List[str]:
|
||||
...
|
||||
|
|
|
@ -11,4 +11,5 @@ from llama_models.schema_utils import webmethod # noqa: F401
|
|||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
|
||||
class Models(Protocol): ...
|
||||
class Models(Protocol):
|
||||
...
|
||||
|
|
|
@ -64,7 +64,6 @@ class PostTrainingRLHFRequest(BaseModel):
|
|||
|
||||
|
||||
class PostTrainingJob(BaseModel):
|
||||
|
||||
job_uuid: str
|
||||
|
||||
|
||||
|
@ -99,30 +98,35 @@ class PostTraining(Protocol):
|
|||
def post_supervised_fine_tune(
|
||||
self,
|
||||
request: PostTrainingSFTRequest,
|
||||
) -> PostTrainingJob: ...
|
||||
) -> PostTrainingJob:
|
||||
...
|
||||
|
||||
@webmethod(route="/post_training/preference_optimize")
|
||||
def post_preference_optimize(
|
||||
self,
|
||||
request: PostTrainingRLHFRequest,
|
||||
) -> PostTrainingJob: ...
|
||||
) -> PostTrainingJob:
|
||||
...
|
||||
|
||||
@webmethod(route="/post_training/jobs")
|
||||
def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
||||
def get_training_jobs(self) -> List[PostTrainingJob]:
|
||||
...
|
||||
|
||||
# sends SSE stream of logs
|
||||
@webmethod(route="/post_training/job/logs")
|
||||
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
||||
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream:
|
||||
...
|
||||
|
||||
@webmethod(route="/post_training/job/status")
|
||||
def get_training_job_status(
|
||||
self, job_uuid: str
|
||||
) -> PostTrainingJobStatusResponse: ...
|
||||
def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
...
|
||||
|
||||
@webmethod(route="/post_training/job/cancel")
|
||||
def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
def cancel_training_job(self, job_uuid: str) -> None:
|
||||
...
|
||||
|
||||
@webmethod(route="/post_training/job/artifacts")
|
||||
def get_training_job_artifacts(
|
||||
self, job_uuid: str
|
||||
) -> PostTrainingJobArtifactsResponse: ...
|
||||
) -> PostTrainingJobArtifactsResponse:
|
||||
...
|
||||
|
|
|
@ -30,4 +30,5 @@ class RewardScoring(Protocol):
|
|||
def post_score(
|
||||
self,
|
||||
request: RewardScoringRequest,
|
||||
) -> Union[RewardScoringResponse]: ...
|
||||
) -> Union[RewardScoringResponse]:
|
||||
...
|
||||
|
|
|
@ -25,9 +25,9 @@ class RunShieldResponse(BaseModel):
|
|||
|
||||
|
||||
class Safety(Protocol):
|
||||
|
||||
@webmethod(route="/safety/run_shields")
|
||||
async def run_shields(
|
||||
self,
|
||||
request: RunShieldRequest,
|
||||
) -> RunShieldResponse: ...
|
||||
) -> RunShieldResponse:
|
||||
...
|
||||
|
|
|
@ -41,7 +41,6 @@ def resolve_and_get_path(model_name: str) -> str:
|
|||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
|
||||
def __init__(self, config: SafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
|||
|
||||
|
||||
class ShieldBase(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
|
@ -60,7 +59,6 @@ class TextShield(ShieldBase):
|
|||
|
||||
|
||||
class DummyShield(TextShield):
|
||||
|
||||
def get_shield_type(self) -> ShieldType:
|
||||
return "dummy"
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
|||
|
||||
|
||||
class CodeScannerShield(TextShield):
|
||||
|
||||
def get_shield_type(self) -> ShieldType:
|
||||
return BuiltinShield.code_scanner_guard
|
||||
|
||||
|
|
|
@ -100,7 +100,6 @@ PROMPT_TEMPLATE = Template(
|
|||
|
||||
|
||||
class LlamaGuardShield(ShieldBase):
|
||||
|
||||
@staticmethod
|
||||
def instance(
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
|
@ -166,7 +165,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
return None
|
||||
|
||||
def get_safety_categories(self) -> List[str]:
|
||||
|
||||
excluded_categories = self.excluded_categories
|
||||
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
|
||||
excluded_categories = []
|
||||
|
@ -181,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
return categories
|
||||
|
||||
def build_prompt(self, messages: List[Message]) -> str:
|
||||
|
||||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
|
@ -225,7 +222,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
is_violation=False,
|
||||
)
|
||||
else:
|
||||
|
||||
prompt = self.build_prompt(messages)
|
||||
llama_guard_input = {
|
||||
"role": "user",
|
||||
|
|
|
@ -18,7 +18,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
|||
|
||||
|
||||
class PromptGuardShield(TextShield):
|
||||
|
||||
class Mode(Enum):
|
||||
INJECTION = auto()
|
||||
JAILBREAK = auto()
|
||||
|
|
|
@ -37,4 +37,5 @@ class SyntheticDataGeneration(Protocol):
|
|||
def post_generate(
|
||||
self,
|
||||
request: SyntheticDataGenerationRequest,
|
||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
||||
) -> Union[SyntheticDataGenerationResponse]:
|
||||
...
|
||||
|
|
|
@ -36,7 +36,6 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <
|
|||
|
||||
|
||||
class InferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# This runs the async setup function
|
||||
|
|
|
@ -20,7 +20,6 @@ from llama_toolchain.inference.ollama.ollama import get_provider_impl
|
|||
|
||||
|
||||
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
ollama_config = OllamaImplConfig(url="http://localhost:11434")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue