formatting

Dalton Flanagan 2024-08-14 14:22:25 -04:00
parent 94dfa293a6
commit b6ccaf1778
33 changed files with 110 additions and 97 deletions


@@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
 @json_schema_type
 class MemoryRetrievalStep(StepCommon):
-    step_type: Literal[StepType.memory_retrieval.value] = (
-        StepType.memory_retrieval.value
-    )
+    step_type: Literal[
+        StepType.memory_retrieval.value
+    ] = StepType.memory_retrieval.value
     memory_bank_ids: List[str]
     documents: List[MemoryBankDocument]
     scores: List[float]
@@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
 @json_schema_type
 class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
-        AgenticSystemTurnResponseEventType.step_start.value
-    )
+    event_type: Literal[
+        AgenticSystemTurnResponseEventType.step_start.value
+    ] = AgenticSystemTurnResponseEventType.step_start.value
     step_type: StepType
     step_id: str
     metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
 @json_schema_type
 class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
-        AgenticSystemTurnResponseEventType.step_complete.value
-    )
+    event_type: Literal[
+        AgenticSystemTurnResponseEventType.step_complete.value
+    ] = AgenticSystemTurnResponseEventType.step_complete.value
     step_type: StepType
     step_details: Step
@@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
 class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
-        AgenticSystemTurnResponseEventType.step_progress.value
-    )
+    event_type: Literal[
+        AgenticSystemTurnResponseEventType.step_progress.value
+    ] = AgenticSystemTurnResponseEventType.step_progress.value
     step_type: StepType
     step_id: str
@@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
 @json_schema_type
 class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
-        AgenticSystemTurnResponseEventType.turn_start.value
-    )
+    event_type: Literal[
+        AgenticSystemTurnResponseEventType.turn_start.value
+    ] = AgenticSystemTurnResponseEventType.turn_start.value
     turn_id: str
 
 
 @json_schema_type
 class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
-        AgenticSystemTurnResponseEventType.turn_complete.value
-    )
+    event_type: Literal[
+        AgenticSystemTurnResponseEventType.turn_complete.value
+    ] = AgenticSystemTurnResponseEventType.turn_complete.value
     turn: Turn
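
(Aside: every class rewrapped above follows the same Pydantic pattern, an event payload whose event_type is typed as a single-value Literal and defaulted to that value, so a union of payloads can be discriminated on it. A minimal, self-contained sketch of that pattern with hypothetical names, not code from this commit:)

from enum import Enum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class EventType(Enum):
    step_start = "step_start"
    step_complete = "step_complete"


class StepStartPayload(BaseModel):
    # Both wrappings in the diff above mean the same thing: the field's type is
    # Literal["step_start"] and its default value is "step_start".
    event_type: Literal[EventType.step_start.value] = EventType.step_start.value
    step_id: str


class StepCompletePayload(BaseModel):
    event_type: Literal[EventType.step_complete.value] = EventType.step_complete.value
    step_id: str


# The pinned Literal lets pydantic pick the right payload class when validating
# a union of payloads by looking at event_type alone.
Payload = Annotated[
    Union[StepStartPayload, StepCompletePayload],
    Field(discriminator="event_type"),
]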


@@ -63,36 +63,40 @@ class AgenticSystemStepResponse(BaseModel):
 class AgenticSystem(Protocol):
-
     @webmethod(route="/agentic_system/create")
     async def create_agentic_system(
         self,
         request: AgenticSystemCreateRequest,
-    ) -> AgenticSystemCreateResponse: ...
+    ) -> AgenticSystemCreateResponse:
+        ...
 
     @webmethod(route="/agentic_system/turn/create")
     async def create_agentic_system_turn(
         self,
         request: AgenticSystemTurnCreateRequest,
-    ) -> AgenticSystemTurnResponseStreamChunk: ...
+    ) -> AgenticSystemTurnResponseStreamChunk:
+        ...
 
     @webmethod(route="/agentic_system/turn/get")
     async def get_agentic_system_turn(
         self,
         agent_id: str,
         turn_id: str,
-    ) -> Turn: ...
+    ) -> Turn:
+        ...
 
     @webmethod(route="/agentic_system/step/get")
     async def get_agentic_system_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse: ...
+    ) -> AgenticSystemStepResponse:
+        ...
 
     @webmethod(route="/agentic_system/session/create")
     async def create_agentic_system_session(
         self,
         request: AgenticSystemSessionCreateRequest,
-    ) -> AgenticSystemSessionCreateResponse: ...
+    ) -> AgenticSystemSessionCreateResponse:
+        ...
 
     @webmethod(route="/agentic_system/memory_bank/attach")
     async def attach_memory_bank_to_agentic_system(
@@ -100,7 +104,8 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/agentic_system/memory_bank/detach")
     async def detach_memory_bank_from_agentic_system(
@@ -108,7 +113,8 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/agentic_system/session/get")
     async def get_agentic_system_session(
@@ -116,15 +122,18 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
-    ) -> Session: ...
+    ) -> Session:
+        ...
 
     @webmethod(route="/agentic_system/session/delete")
     async def delete_agentic_system_session(
         self, agent_id: str, session_id: str
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/agentic_system/delete")
     async def delete_agentic_system(
         self,
         agent_id: str,
-    ) -> None: ...
+    ) -> None:
+        ...


@@ -246,7 +246,6 @@ class AgentInstance(ShieldRunnerMixin):
 
             await self.run_shields(messages, shields)
         except SafetyException as e:
-
             yield AgenticSystemTurnResponseStreamChunk(
                 event=AgenticSystemTurnResponseEvent(
                     payload=AgenticSystemTurnResponseStepCompletePayload(


@@ -23,7 +23,6 @@ class SafetyException(Exception):  # noqa: N818
 
 
 class ShieldRunnerMixin:
-
     def __init__(
         self,
         safety_api: Safety,


@@ -11,7 +11,6 @@ from llama_toolchain.inference.api import Message
 
 
 class BaseTool(ABC):
-
     @abstractmethod
     def get_name(self) -> str:
         raise NotImplementedError


@@ -66,7 +66,6 @@ class SingleMessageBuiltinTool(BaseTool):
 
 
 class PhotogenTool(SingleMessageBuiltinTool):
-
     def __init__(self, dump_dir: str) -> None:
         self.dump_dir = dump_dir
@@ -87,7 +86,6 @@ class PhotogenTool(SingleMessageBuiltinTool):
 
 
 class BraveSearchTool(SingleMessageBuiltinTool):
-
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
@@ -204,7 +202,6 @@ class BraveSearchTool(SingleMessageBuiltinTool):
 
 
 class WolframAlphaTool(SingleMessageBuiltinTool):
-
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
         self.url = "https://api.wolframalpha.com/v2/query"
@@ -287,7 +284,6 @@ class WolframAlphaTool(SingleMessageBuiltinTool):
 
 
 class CodeInterpreterTool(BaseTool):
-
     def __init__(self) -> None:
         ctx = CodeExecutionContext(
             matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",


@@ -27,7 +27,6 @@ from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
 
 
 class AgenticSystemClientWrapper:
-
     def __init__(self, api, system_id, custom_tools):
         self.api = api
         self.system_id = system_id


@@ -10,7 +10,6 @@ from llama_toolchain.cli.subcommand import Subcommand
 
 
 class DistributionCreate(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(


@@ -16,7 +16,6 @@ from .start import DistributionStart
 
 
 class DistributionParser(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(


@@ -11,7 +11,6 @@ from llama_toolchain.cli.subcommand import Subcommand
 
 
 class DistributionList(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(


@@ -14,7 +14,6 @@ from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
 
 
 class DistributionStart(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(


@@ -27,7 +27,11 @@ class ModelList(Subcommand):
         self.parser.set_defaults(func=self._run_model_list_cmd)
 
     def _add_arguments(self):
-        self.parser.add_argument('--show-all', action='store_true', help='Show all models (not just defaults)')
+        self.parser.add_argument(
+            "--show-all",
+            action="store_true",
+            help="Show all models (not just defaults)",
+        )
 
     def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
         headers = [

@@ -26,16 +26,19 @@ class Datasets(Protocol):
     def create_dataset(
         self,
         request: CreateDatasetRequest,
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/datasets/get")
     def get_dataset(
         self,
         dataset_uuid: str,
-    ) -> TrainEvalDataset: ...
+    ) -> TrainEvalDataset:
+        ...
 
     @webmethod(route="/datasets/delete")
     def delete_dataset(
         self,
         dataset_uuid: str,
-    ) -> None: ...
+    ) -> None:
+        ...
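
(Aside: the `...`-bodied methods being rewrapped throughout these API files are stub declarations on typing.Protocol classes; a concrete provider implements the same signatures and satisfies the interface structurally. A minimal sketch of that idea with hypothetical names, independent of the @webmethod routing decorator used in this repo:)

from typing import Dict, Protocol


class DatasetsAPI(Protocol):  # hypothetical, trimmed-down interface
    def get_dataset(self, dataset_uuid: str) -> Dict:
        ...  # stub body: `) -> Dict: ...` and a separate `...` line are equivalent


class InMemoryDatasets:
    def __init__(self) -> None:
        self._store: Dict[str, Dict] = {}

    def create_dataset(self, dataset_uuid: str, rows: Dict) -> None:
        self._store[dataset_uuid] = rows

    def get_dataset(self, dataset_uuid: str) -> Dict:
        return self._store[dataset_uuid]


def describe(api: DatasetsAPI, dataset_uuid: str) -> Dict:
    # Structural typing: InMemoryDatasets never inherits from DatasetsAPI,
    # yet it satisfies the Protocol, so a type checker accepts passing it here.
    return api.get_dataset(dataset_uuid)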


@@ -217,7 +217,6 @@ def create_dynamic_typed_route(func: Any):
 
 
 def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
-
     by_id = {x.api: x for x in providers}
 
     def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):


@@ -26,10 +26,8 @@ class SummarizationMetric(Enum):
 
 
 class EvaluationJob(BaseModel):
-
     job_uuid: str
 
 
 class EvaluationJobLogStream(BaseModel):
-
     job_uuid: str


@@ -48,7 +48,6 @@ class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
 
 
 class EvaluationJobStatusResponse(BaseModel):
-
     job_uuid: str
@@ -64,36 +63,42 @@ class Evaluations(Protocol):
     def post_evaluate_text_generation(
         self,
         request: EvaluateTextGenerationRequest,
-    ) -> EvaluationJob: ...
+    ) -> EvaluationJob:
+        ...
 
     @webmethod(route="/evaluate/question_answering/")
     def post_evaluate_question_answering(
         self,
         request: EvaluateQuestionAnsweringRequest,
-    ) -> EvaluationJob: ...
+    ) -> EvaluationJob:
+        ...
 
     @webmethod(route="/evaluate/summarization/")
     def post_evaluate_summarization(
         self,
         request: EvaluateSummarizationRequest,
-    ) -> EvaluationJob: ...
+    ) -> EvaluationJob:
+        ...
 
     @webmethod(route="/evaluate/jobs")
-    def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
+    def get_evaluation_jobs(self) -> List[EvaluationJob]:
+        ...
 
     @webmethod(route="/evaluate/job/status")
-    def get_evaluation_job_status(
-        self, job_uuid: str
-    ) -> EvaluationJobStatusResponse: ...
+    def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
+        ...
 
     # sends SSE stream of logs
     @webmethod(route="/evaluate/job/logs")
-    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
+    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
+        ...
 
     @webmethod(route="/evaluate/job/cancel")
-    def cancel_evaluation_job(self, job_uuid: str) -> None: ...
+    def cancel_evaluation_job(self, job_uuid: str) -> None:
+        ...
 
     @webmethod(route="/evaluate/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
-    ) -> EvaluationJobArtifactsResponse: ...
+    ) -> EvaluationJobArtifactsResponse:
+        ...


@@ -97,27 +97,30 @@ class BatchChatCompletionResponse(BaseModel):
 class Inference(Protocol):
-
     @webmethod(route="/inference/completion")
     async def completion(
         self,
         request: CompletionRequest,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
+    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        ...
 
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
         request: ChatCompletionRequest,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
+    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
+        ...
 
     @webmethod(route="/inference/batch_completion")
     async def batch_completion(
         self,
         request: BatchCompletionRequest,
-    ) -> BatchCompletionResponse: ...
+    ) -> BatchCompletionResponse:
+        ...
 
     @webmethod(route="/inference/batch_chat_completion")
     async def batch_chat_completion(
         self,
         request: BatchChatCompletionRequest,
-    ) -> BatchChatCompletionResponse: ...
+    ) -> BatchChatCompletionResponse:
+        ...


@@ -7,8 +7,8 @@
 from termcolor import cprint
 
 from llama_toolchain.inference.api import (
     ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk
+    ChatCompletionResponseStreamChunk,
 )


@@ -45,7 +45,6 @@ SEMAPHORE = asyncio.Semaphore(1)
 
 
 class MetaReferenceInferenceImpl(Inference):
-
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)


@@ -54,7 +54,6 @@ async def get_provider_impl(
 
 
 class OllamaInference(Inference):
-
     def __init__(self, config: OllamaImplConfig) -> None:
         self.config = config
@@ -66,7 +65,9 @@ class OllamaInference(Inference):
         try:
             await self.client.ps()
         except httpx.ConnectError:
-            raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            raise RuntimeError(
+                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+            )
 
     async def shutdown(self) -> None:
         pass
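
(Aside: the reflowed RuntimeError above belongs to a startup connectivity check. A standalone sketch of that pattern with a hypothetical helper name, assuming the ollama Python client and httpx that this hunk already references:)

import httpx
from ollama import AsyncClient


async def ensure_ollama_running(url: str = "http://localhost:11434") -> AsyncClient:
    client = AsyncClient(host=url)
    try:
        # ps() is a cheap call that lists running models; it fails with a
        # connection error if no server is listening at `url`.
        await client.ps()
    except httpx.ConnectError:
        raise RuntimeError(
            "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
        )
    return client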


@@ -18,44 +18,52 @@ class MemoryBanks(Protocol):
         bank_id: str,
         bank_name: str,
         documents: List[MemoryBankDocument],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/memory_banks/list")
-    def get_memory_banks(self) -> List[MemoryBank]: ...
+    def get_memory_banks(self) -> List[MemoryBank]:
+        ...
 
     @webmethod(route="/memory_banks/get")
-    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
+    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
+        ...
 
     @webmethod(route="/memory_banks/drop")
     def delete_memory_bank(
         self,
         bank_id: str,
-    ) -> str: ...
+    ) -> str:
+        ...
 
     @webmethod(route="/memory_bank/insert")
     def post_insert_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/memory_bank/update")
     def post_update_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/memory_bank/get")
     def get_memory_documents(
         self,
         bank_id: str,
         document_uuids: List[str],
-    ) -> List[MemoryBankDocument]: ...
+    ) -> List[MemoryBankDocument]:
+        ...
 
     @webmethod(route="/memory_bank/delete")
     def delete_memory_documents(
         self,
         bank_id: str,
         document_uuids: List[str],
-    ) -> List[str]: ...
+    ) -> List[str]:
+        ...


@@ -11,4 +11,5 @@ from llama_models.schema_utils import webmethod  # noqa: F401
 from pydantic import BaseModel  # noqa: F401
 
 
-class Models(Protocol): ...
+class Models(Protocol):
+    ...


@@ -64,7 +64,6 @@ class PostTrainingRLHFRequest(BaseModel):
 
 
 class PostTrainingJob(BaseModel):
-
     job_uuid: str
@@ -99,30 +98,35 @@ class PostTraining(Protocol):
     def post_supervised_fine_tune(
         self,
         request: PostTrainingSFTRequest,
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        ...
 
     @webmethod(route="/post_training/preference_optimize")
     def post_preference_optimize(
         self,
         request: PostTrainingRLHFRequest,
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        ...
 
     @webmethod(route="/post_training/jobs")
-    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    def get_training_jobs(self) -> List[PostTrainingJob]:
+        ...
 
     # sends SSE stream of logs
     @webmethod(route="/post_training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream:
+        ...
 
     @webmethod(route="/post_training/job/status")
-    def get_training_job_status(
-        self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
+        ...
 
     @webmethod(route="/post_training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None: ...
+    def cancel_training_job(self, job_uuid: str) -> None:
+        ...
 
     @webmethod(route="/post_training/job/artifacts")
     def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> PostTrainingJobArtifactsResponse:
+        ...


@@ -30,4 +30,5 @@ class RewardScoring(Protocol):
     def post_score(
         self,
         request: RewardScoringRequest,
-    ) -> Union[RewardScoringResponse]: ...
+    ) -> Union[RewardScoringResponse]:
+        ...


@@ -25,9 +25,9 @@ class RunShieldResponse(BaseModel):
 class Safety(Protocol):
-
     @webmethod(route="/safety/run_shields")
     async def run_shields(
         self,
         request: RunShieldRequest,
-    ) -> RunShieldResponse: ...
+    ) -> RunShieldResponse:
+        ...


@@ -41,7 +41,6 @@ def resolve_and_get_path(model_name: str) -> str:
 
 
 class MetaReferenceSafetyImpl(Safety):
-
     def __init__(self, config: SafetyConfig) -> None:
         self.config = config


@@ -14,7 +14,6 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 
 class ShieldBase(ABC):
-
     def __init__(
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
@@ -60,7 +59,6 @@ class TextShield(ShieldBase):
 
 
 class DummyShield(TextShield):
-
     def get_shield_type(self) -> ShieldType:
         return "dummy"


@@ -12,7 +12,6 @@ from llama_toolchain.safety.api.datatypes import *  # noqa: F403
 
 
 class CodeScannerShield(TextShield):
-
     def get_shield_type(self) -> ShieldType:
         return BuiltinShield.code_scanner_guard


@@ -100,7 +100,6 @@ PROMPT_TEMPLATE = Template(
 
 
 class LlamaGuardShield(ShieldBase):
-
     @staticmethod
     def instance(
         on_violation_action=OnViolationAction.RAISE,
@@ -166,7 +165,6 @@ class LlamaGuardShield(ShieldBase):
         return None
 
     def get_safety_categories(self) -> List[str]:
-
         excluded_categories = self.excluded_categories
         if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
             excluded_categories = []
@@ -181,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
         return categories
 
     def build_prompt(self, messages: List[Message]) -> str:
-
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
@@ -225,7 +222,6 @@ class LlamaGuardShield(ShieldBase):
                 is_violation=False,
             )
         else:
-
             prompt = self.build_prompt(messages)
             llama_guard_input = {
                 "role": "user",


@@ -18,7 +18,6 @@ from llama_toolchain.safety.api.datatypes import *  # noqa: F403
 
 
 class PromptGuardShield(TextShield):
-
     class Mode(Enum):
         INJECTION = auto()
         JAILBREAK = auto()


@@ -37,4 +37,5 @@ class SyntheticDataGeneration(Protocol):
     def post_generate(
         self,
         request: SyntheticDataGenerationRequest,
-    ) -> Union[SyntheticDataGenerationResponse]: ...
+    ) -> Union[SyntheticDataGenerationResponse]:
+        ...


@@ -36,7 +36,6 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <
 
 
 class InferenceTests(unittest.IsolatedAsyncioTestCase):
-
     @classmethod
     def setUpClass(cls):
         # This runs the async setup function


@@ -20,7 +20,6 @@ from llama_toolchain.inference.ollama.ollama import get_provider_impl
 
 
 class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
-
     async def asyncSetUp(self):
         ollama_config = OllamaImplConfig(url="http://localhost:11434")