From b6ccaf1778845d363974e8de970912ef26936076 Mon Sep 17 00:00:00 2001
From: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
Date: Wed, 14 Aug 2024 14:22:25 -0400
Subject: [PATCH] formatting

---
 .../agentic_system/api/datatypes.py           | 24 +++++++-------
 .../agentic_system/api/endpoints.py           | 31 ++++++++++++-------
 .../meta_reference/agent_instance.py          |  1 -
 .../agentic_system/meta_reference/safety.py   |  1 -
 .../meta_reference/tools/base.py              |  1 -
 .../meta_reference/tools/builtin.py           |  4 ---
 llama_toolchain/agentic_system/utils.py       |  1 -
 llama_toolchain/cli/distribution/create.py    |  1 -
 .../cli/distribution/distribution.py          |  1 -
 llama_toolchain/cli/distribution/list.py      |  1 -
 llama_toolchain/cli/distribution/start.py     |  1 -
 llama_toolchain/cli/model/list.py             |  6 +++-
 llama_toolchain/dataset/api/endpoints.py      |  9 ++++--
 llama_toolchain/distribution/server.py        |  1 -
 llama_toolchain/evaluations/api/datatypes.py  |  2 --
 llama_toolchain/evaluations/api/endpoints.py  | 27 +++++++++-------
 llama_toolchain/inference/api/endpoints.py    | 13 +++++---
 llama_toolchain/inference/event_logger.py     |  4 +--
 .../inference/meta_reference/inference.py     |  1 -
 llama_toolchain/inference/ollama/ollama.py    |  5 +--
 llama_toolchain/memory/api/endpoints.py       | 24 +++++++++-----
 llama_toolchain/models/api/endpoints.py       |  3 +-
 .../post_training/api/endpoints.py            | 24 ++++++++------
 .../reward_scoring/api/endpoints.py           |  3 +-
 llama_toolchain/safety/api/endpoints.py       |  4 +--
 .../safety/meta_reference/safety.py           |  1 -
 .../safety/meta_reference/shields/base.py     |  2 --
 .../meta_reference/shields/code_scanner.py    |  1 -
 .../meta_reference/shields/llama_guard.py     |  4 ---
 .../meta_reference/shields/prompt_guard.py    |  1 -
 .../api/endpoints.py                          |  3 +-
 tests/test_inference.py                       |  1 -
 tests/test_ollama_inference.py                |  1 -
 33 files changed, 110 insertions(+), 97 deletions(-)

diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
index 1dda64834..0da30f228 100644
--- a/llama_toolchain/agentic_system/api/datatypes.py
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
 
 @json_schema_type
 class MemoryRetrievalStep(StepCommon):
-    step_type: Literal[StepType.memory_retrieval.value] = (
+    step_type: Literal[
         StepType.memory_retrieval.value
-    )
+    ] = StepType.memory_retrieval.value
     memory_bank_ids: List[str]
     documents: List[MemoryBankDocument]
     scores: List[float]
@@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
 
 @json_schema_type
 class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
+    event_type: Literal[
         AgenticSystemTurnResponseEventType.step_start.value
-    )
+    ] = AgenticSystemTurnResponseEventType.step_start.value
     step_type: StepType
     step_id: str
     metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
 
 @json_schema_type
 class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
+    event_type: Literal[
         AgenticSystemTurnResponseEventType.step_complete.value
-    )
+    ] = AgenticSystemTurnResponseEventType.step_complete.value
     step_type: StepType
     step_details: Step
 
@@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
 
     model_config = ConfigDict(protected_namespaces=())
 
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
+    event_type: Literal[
         AgenticSystemTurnResponseEventType.step_progress.value
-    )
+    ] = AgenticSystemTurnResponseEventType.step_progress.value
     step_type: StepType
     step_id: str
 
@@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
 
 @json_schema_type
 class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
+    event_type: Literal[
         AgenticSystemTurnResponseEventType.turn_start.value
-    )
+    ] = AgenticSystemTurnResponseEventType.turn_start.value
     turn_id: str
 
 
 @json_schema_type
 class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
+    event_type: Literal[
         AgenticSystemTurnResponseEventType.turn_complete.value
-    )
+    ] = AgenticSystemTurnResponseEventType.turn_complete.value
     turn: Turn
 
 
diff --git a/llama_toolchain/agentic_system/api/endpoints.py b/llama_toolchain/agentic_system/api/endpoints.py
index be25b7d10..ee52f9b5a 100644
--- a/llama_toolchain/agentic_system/api/endpoints.py
+++ b/llama_toolchain/agentic_system/api/endpoints.py
@@ -63,36 +63,40 @@ class AgenticSystemStepResponse(BaseModel):
 
 
 class AgenticSystem(Protocol):
-
     @webmethod(route="/agentic_system/create")
     async def create_agentic_system(
         self,
         request: AgenticSystemCreateRequest,
-    ) -> AgenticSystemCreateResponse: ...
+    ) -> AgenticSystemCreateResponse:
+        ...
 
     @webmethod(route="/agentic_system/turn/create")
     async def create_agentic_system_turn(
         self,
         request: AgenticSystemTurnCreateRequest,
-    ) -> AgenticSystemTurnResponseStreamChunk: ...
+    ) -> AgenticSystemTurnResponseStreamChunk:
+        ...
 
     @webmethod(route="/agentic_system/turn/get")
     async def get_agentic_system_turn(
         self,
         agent_id: str,
         turn_id: str,
-    ) -> Turn: ...
+    ) -> Turn:
+        ...
 
     @webmethod(route="/agentic_system/step/get")
     async def get_agentic_system_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse: ...
+    ) -> AgenticSystemStepResponse:
+        ...
 
     @webmethod(route="/agentic_system/session/create")
     async def create_agentic_system_session(
         self,
         request: AgenticSystemSessionCreateRequest,
-    ) -> AgenticSystemSessionCreateResponse: ...
+    ) -> AgenticSystemSessionCreateResponse:
+        ...
 
     @webmethod(route="/agentic_system/memory_bank/attach")
     async def attach_memory_bank_to_agentic_system(
@@ -100,7 +104,8 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/agentic_system/memory_bank/detach")
     async def detach_memory_bank_from_agentic_system(
@@ -108,7 +113,8 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/agentic_system/session/get")
     async def get_agentic_system_session(
@@ -116,15 +122,18 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
-    ) -> Session: ...
+    ) -> Session:
+        ...
 
     @webmethod(route="/agentic_system/session/delete")
     async def delete_agentic_system_session(
         self, agent_id: str, session_id: str
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/agentic_system/delete")
     async def delete_agentic_system(
         self,
         agent_id: str,
-    ) -> None: ...
+    ) -> None:
+        ...
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 06a5bb3db..8e4555cb4 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -246,7 +246,6 @@ class AgentInstance(ShieldRunnerMixin):
 
             await self.run_shields(messages, shields)
         except SafetyException as e:
-
             yield AgenticSystemTurnResponseStreamChunk(
                 event=AgenticSystemTurnResponseEvent(
                     payload=AgenticSystemTurnResponseStepCompletePayload(
diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py
index f066baf59..c78cb3028 100644
--- a/llama_toolchain/agentic_system/meta_reference/safety.py
+++ b/llama_toolchain/agentic_system/meta_reference/safety.py
@@ -23,7 +23,6 @@ class SafetyException(Exception):  # noqa: N818
 
 
 class ShieldRunnerMixin:
-
     def __init__(
         self,
         safety_api: Safety,
diff --git a/llama_toolchain/agentic_system/meta_reference/tools/base.py b/llama_toolchain/agentic_system/meta_reference/tools/base.py
index 3c2722305..324cce0e2 100644
--- a/llama_toolchain/agentic_system/meta_reference/tools/base.py
+++ b/llama_toolchain/agentic_system/meta_reference/tools/base.py
@@ -11,7 +11,6 @@ from llama_toolchain.inference.api import Message
 
 
 class BaseTool(ABC):
-
     @abstractmethod
     def get_name(self) -> str:
         raise NotImplementedError
diff --git a/llama_toolchain/agentic_system/meta_reference/tools/builtin.py b/llama_toolchain/agentic_system/meta_reference/tools/builtin.py
index 4487a2692..c13af125f 100644
--- a/llama_toolchain/agentic_system/meta_reference/tools/builtin.py
+++ b/llama_toolchain/agentic_system/meta_reference/tools/builtin.py
@@ -66,7 +66,6 @@ class SingleMessageBuiltinTool(BaseTool):
 
 
 class PhotogenTool(SingleMessageBuiltinTool):
-
     def __init__(self, dump_dir: str) -> None:
         self.dump_dir = dump_dir
 
@@ -87,7 +86,6 @@ class PhotogenTool(SingleMessageBuiltinTool):
 
 
 class BraveSearchTool(SingleMessageBuiltinTool):
-
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
 
@@ -204,7 +202,6 @@ class BraveSearchTool(SingleMessageBuiltinTool):
 
 
 class WolframAlphaTool(SingleMessageBuiltinTool):
-
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
         self.url = "https://api.wolframalpha.com/v2/query"
@@ -287,7 +284,6 @@ class WolframAlphaTool(SingleMessageBuiltinTool):
 
 
 class CodeInterpreterTool(BaseTool):
-
     def __init__(self) -> None:
         ctx = CodeExecutionContext(
             matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",
diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py
index 299c5f93b..bc1639b3d 100644
--- a/llama_toolchain/agentic_system/utils.py
+++ b/llama_toolchain/agentic_system/utils.py
@@ -27,7 +27,6 @@ from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
 
 
 class AgenticSystemClientWrapper:
-
     def __init__(self, api, system_id, custom_tools):
         self.api = api
         self.system_id = system_id
diff --git a/llama_toolchain/cli/distribution/create.py b/llama_toolchain/cli/distribution/create.py
index 140f1027d..f4b6d3f20 100644
--- a/llama_toolchain/cli/distribution/create.py
+++ b/llama_toolchain/cli/distribution/create.py
@@ -10,7 +10,6 @@ from llama_toolchain.cli.subcommand import Subcommand
 
 
 class DistributionCreate(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
diff --git a/llama_toolchain/cli/distribution/distribution.py b/llama_toolchain/cli/distribution/distribution.py
index afc5f9341..641f360e9 100644
--- a/llama_toolchain/cli/distribution/distribution.py
+++ b/llama_toolchain/cli/distribution/distribution.py
@@ -16,7 +16,6 @@ from .start import DistributionStart
 
 
 class DistributionParser(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
diff --git a/llama_toolchain/cli/distribution/list.py b/llama_toolchain/cli/distribution/list.py
index b285f2006..e214490ef 100644
--- a/llama_toolchain/cli/distribution/list.py
+++ b/llama_toolchain/cli/distribution/list.py
@@ -11,7 +11,6 @@ from llama_toolchain.cli.subcommand import Subcommand
 
 
 class DistributionList(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
diff --git a/llama_toolchain/cli/distribution/start.py b/llama_toolchain/cli/distribution/start.py
index 8620550db..b854c79dc 100644
--- a/llama_toolchain/cli/distribution/start.py
+++ b/llama_toolchain/cli/distribution/start.py
@@ -14,7 +14,6 @@ from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
 
 
 class DistributionStart(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
diff --git a/llama_toolchain/cli/model/list.py b/llama_toolchain/cli/model/list.py
index e7e1e4054..cbbed7e54 100644
--- a/llama_toolchain/cli/model/list.py
+++ b/llama_toolchain/cli/model/list.py
@@ -27,7 +27,11 @@ class ModelList(Subcommand):
         self.parser.set_defaults(func=self._run_model_list_cmd)
 
     def _add_arguments(self):
-        self.parser.add_argument('--show-all', action='store_true', help='Show all models (not just defaults)')
+        self.parser.add_argument(
+            "--show-all",
+            action="store_true",
+            help="Show all models (not just defaults)",
+        )
 
     def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
         headers = [
diff --git a/llama_toolchain/dataset/api/endpoints.py b/llama_toolchain/dataset/api/endpoints.py
index 6a88f4b7a..92d0f40c9 100644
--- a/llama_toolchain/dataset/api/endpoints.py
+++ b/llama_toolchain/dataset/api/endpoints.py
@@ -26,16 +26,19 @@ class Datasets(Protocol):
     def create_dataset(
         self,
         request: CreateDatasetRequest,
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/datasets/get")
     def get_dataset(
         self,
         dataset_uuid: str,
-    ) -> TrainEvalDataset: ...
+    ) -> TrainEvalDataset:
+        ...
 
     @webmethod(route="/datasets/delete")
     def delete_dataset(
         self,
         dataset_uuid: str,
-    ) -> None: ...
+    ) -> None:
+        ...
diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py
index a087b3a64..d1140e6a0 100644
--- a/llama_toolchain/distribution/server.py
+++ b/llama_toolchain/distribution/server.py
@@ -217,7 +217,6 @@ def create_dynamic_typed_route(func: Any):
 
 
 def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
-
     by_id = {x.api: x for x in providers}
 
     def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
diff --git a/llama_toolchain/evaluations/api/datatypes.py b/llama_toolchain/evaluations/api/datatypes.py
index 9413fcd7c..0ba284e9d 100644
--- a/llama_toolchain/evaluations/api/datatypes.py
+++ b/llama_toolchain/evaluations/api/datatypes.py
@@ -26,10 +26,8 @@ class SummarizationMetric(Enum):
 
 
 class EvaluationJob(BaseModel):
-
     job_uuid: str
 
 
 class EvaluationJobLogStream(BaseModel):
-
     job_uuid: str
diff --git a/llama_toolchain/evaluations/api/endpoints.py b/llama_toolchain/evaluations/api/endpoints.py
index af5724f2a..d8e932c88 100644
--- a/llama_toolchain/evaluations/api/endpoints.py
+++ b/llama_toolchain/evaluations/api/endpoints.py
@@ -48,7 +48,6 @@ class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
 
 
 class EvaluationJobStatusResponse(BaseModel):
-
     job_uuid: str
 
 
@@ -64,36 +63,42 @@ class Evaluations(Protocol):
     def post_evaluate_text_generation(
         self,
         request: EvaluateTextGenerationRequest,
-    ) -> EvaluationJob: ...
+    ) -> EvaluationJob:
+        ...
 
     @webmethod(route="/evaluate/question_answering/")
     def post_evaluate_question_answering(
         self,
         request: EvaluateQuestionAnsweringRequest,
-    ) -> EvaluationJob: ...
+    ) -> EvaluationJob:
+        ...
 
     @webmethod(route="/evaluate/summarization/")
     def post_evaluate_summarization(
         self,
         request: EvaluateSummarizationRequest,
-    ) -> EvaluationJob: ...
+    ) -> EvaluationJob:
+        ...
 
     @webmethod(route="/evaluate/jobs")
-    def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
+    def get_evaluation_jobs(self) -> List[EvaluationJob]:
+        ...
 
     @webmethod(route="/evaluate/job/status")
-    def get_evaluation_job_status(
-        self, job_uuid: str
-    ) -> EvaluationJobStatusResponse: ...
+    def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
+        ...
 
     # sends SSE stream of logs
     @webmethod(route="/evaluate/job/logs")
-    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
+    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
+        ...
 
     @webmethod(route="/evaluate/job/cancel")
-    def cancel_evaluation_job(self, job_uuid: str) -> None: ...
+    def cancel_evaluation_job(self, job_uuid: str) -> None:
+        ...
 
     @webmethod(route="/evaluate/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
-    ) -> EvaluationJobArtifactsResponse: ...
+    ) -> EvaluationJobArtifactsResponse:
+        ...
diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py
index f225f5b5c..fe81695a0 100644
--- a/llama_toolchain/inference/api/endpoints.py
+++ b/llama_toolchain/inference/api/endpoints.py
@@ -97,27 +97,30 @@ class BatchChatCompletionResponse(BaseModel):
 
 
 class Inference(Protocol):
-
     @webmethod(route="/inference/completion")
     async def completion(
         self,
         request: CompletionRequest,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
+    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        ...
 
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
         request: ChatCompletionRequest,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
+    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
+        ...
 
     @webmethod(route="/inference/batch_completion")
     async def batch_completion(
         self,
         request: BatchCompletionRequest,
-    ) -> BatchCompletionResponse: ...
+    ) -> BatchCompletionResponse:
+        ...
 
     @webmethod(route="/inference/batch_chat_completion")
     async def batch_chat_completion(
         self,
         request: BatchChatCompletionRequest,
-    ) -> BatchChatCompletionResponse: ...
+    ) -> BatchChatCompletionResponse:
+        ...
diff --git a/llama_toolchain/inference/event_logger.py b/llama_toolchain/inference/event_logger.py
index 9d9434b6a..ebe241395 100644
--- a/llama_toolchain/inference/event_logger.py
+++ b/llama_toolchain/inference/event_logger.py
@@ -7,8 +7,8 @@
 from termcolor import cprint
 
 from llama_toolchain.inference.api import (
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 
 
diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py
index b41cb4acb..4bd7a80bc 100644
--- a/llama_toolchain/inference/meta_reference/inference.py
+++ b/llama_toolchain/inference/meta_reference/inference.py
@@ -45,7 +45,6 @@ SEMAPHORE = asyncio.Semaphore(1)
 
 
 class MetaReferenceInferenceImpl(Inference):
-
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/ollama/ollama.py
index 4f4a161ad..64f24bee4 100644
--- a/llama_toolchain/inference/ollama/ollama.py
+++ b/llama_toolchain/inference/ollama/ollama.py
@@ -54,7 +54,6 @@ async def get_provider_impl(
 
 
 class OllamaInference(Inference):
-
     def __init__(self, config: OllamaImplConfig) -> None:
         self.config = config
 
@@ -66,7 +65,9 @@ class OllamaInference(Inference):
         try:
             await self.client.ps()
         except httpx.ConnectError:
-            raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            raise RuntimeError(
+                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+            )
 
     async def shutdown(self) -> None:
         pass
diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py
index 4261afa89..66c6719a4 100644
--- a/llama_toolchain/memory/api/endpoints.py
+++ b/llama_toolchain/memory/api/endpoints.py
@@ -18,44 +18,52 @@ class MemoryBanks(Protocol):
         bank_id: str,
         bank_name: str,
         documents: List[MemoryBankDocument],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/memory_banks/list")
-    def get_memory_banks(self) -> List[MemoryBank]: ...
+    def get_memory_banks(self) -> List[MemoryBank]:
+        ...
 
     @webmethod(route="/memory_banks/get")
-    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
+    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
+        ...
 
     @webmethod(route="/memory_banks/drop")
     def delete_memory_bank(
         self,
         bank_id: str,
-    ) -> str: ...
+    ) -> str:
+        ...
 
     @webmethod(route="/memory_bank/insert")
     def post_insert_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/memory_bank/update")
     def post_update_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @webmethod(route="/memory_bank/get")
     def get_memory_documents(
         self,
         bank_id: str,
         document_uuids: List[str],
-    ) -> List[MemoryBankDocument]: ...
+    ) -> List[MemoryBankDocument]:
+        ...
@webmethod(route="/memory_bank/delete") def delete_memory_documents( self, bank_id: str, document_uuids: List[str], - ) -> List[str]: ... + ) -> List[str]: + ... diff --git a/llama_toolchain/models/api/endpoints.py b/llama_toolchain/models/api/endpoints.py index ee1d5f0ba..250df62e7 100644 --- a/llama_toolchain/models/api/endpoints.py +++ b/llama_toolchain/models/api/endpoints.py @@ -11,4 +11,5 @@ from llama_models.schema_utils import webmethod # noqa: F401 from pydantic import BaseModel # noqa: F401 -class Models(Protocol): ... +class Models(Protocol): + ... diff --git a/llama_toolchain/post_training/api/endpoints.py b/llama_toolchain/post_training/api/endpoints.py index 4d9c4c02b..795e1d6f8 100644 --- a/llama_toolchain/post_training/api/endpoints.py +++ b/llama_toolchain/post_training/api/endpoints.py @@ -64,7 +64,6 @@ class PostTrainingRLHFRequest(BaseModel): class PostTrainingJob(BaseModel): - job_uuid: str @@ -99,30 +98,35 @@ class PostTraining(Protocol): def post_supervised_fine_tune( self, request: PostTrainingSFTRequest, - ) -> PostTrainingJob: ... + ) -> PostTrainingJob: + ... @webmethod(route="/post_training/preference_optimize") def post_preference_optimize( self, request: PostTrainingRLHFRequest, - ) -> PostTrainingJob: ... + ) -> PostTrainingJob: + ... @webmethod(route="/post_training/jobs") - def get_training_jobs(self) -> List[PostTrainingJob]: ... + def get_training_jobs(self) -> List[PostTrainingJob]: + ... # sends SSE stream of logs @webmethod(route="/post_training/job/logs") - def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ... + def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: + ... @webmethod(route="/post_training/job/status") - def get_training_job_status( - self, job_uuid: str - ) -> PostTrainingJobStatusResponse: ... + def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: + ... @webmethod(route="/post_training/job/cancel") - def cancel_training_job(self, job_uuid: str) -> None: ... + def cancel_training_job(self, job_uuid: str) -> None: + ... @webmethod(route="/post_training/job/artifacts") def get_training_job_artifacts( self, job_uuid: str - ) -> PostTrainingJobArtifactsResponse: ... + ) -> PostTrainingJobArtifactsResponse: + ... diff --git a/llama_toolchain/reward_scoring/api/endpoints.py b/llama_toolchain/reward_scoring/api/endpoints.py index 0a7327a9b..7ff059f13 100644 --- a/llama_toolchain/reward_scoring/api/endpoints.py +++ b/llama_toolchain/reward_scoring/api/endpoints.py @@ -30,4 +30,5 @@ class RewardScoring(Protocol): def post_score( self, request: RewardScoringRequest, - ) -> Union[RewardScoringResponse]: ... + ) -> Union[RewardScoringResponse]: + ... diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py index 984a58a89..c49577674 100644 --- a/llama_toolchain/safety/api/endpoints.py +++ b/llama_toolchain/safety/api/endpoints.py @@ -25,9 +25,9 @@ class RunShieldResponse(BaseModel): class Safety(Protocol): - @webmethod(route="/safety/run_shields") async def run_shields( self, request: RunShieldRequest, - ) -> RunShieldResponse: ... + ) -> RunShieldResponse: + ... 
diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py
index 60d16dbf1..426376c2d 100644
--- a/llama_toolchain/safety/meta_reference/safety.py
+++ b/llama_toolchain/safety/meta_reference/safety.py
@@ -41,7 +41,6 @@ def resolve_and_get_path(model_name: str) -> str:
 
 
 class MetaReferenceSafetyImpl(Safety):
-
     def __init__(self, config: SafetyConfig) -> None:
         self.config = config
 
diff --git a/llama_toolchain/safety/meta_reference/shields/base.py b/llama_toolchain/safety/meta_reference/shields/base.py
index c4e2aa830..ce19a3676 100644
--- a/llama_toolchain/safety/meta_reference/shields/base.py
+++ b/llama_toolchain/safety/meta_reference/shields/base.py
@@ -14,7 +14,6 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 
 class ShieldBase(ABC):
-
     def __init__(
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
@@ -60,7 +59,6 @@ class TextShield(ShieldBase):
 
 
 class DummyShield(TextShield):
-
     def get_shield_type(self) -> ShieldType:
         return "dummy"
 
diff --git a/llama_toolchain/safety/meta_reference/shields/code_scanner.py b/llama_toolchain/safety/meta_reference/shields/code_scanner.py
index 6fcf3e6a9..f78260ff1 100644
--- a/llama_toolchain/safety/meta_reference/shields/code_scanner.py
+++ b/llama_toolchain/safety/meta_reference/shields/code_scanner.py
@@ -12,7 +12,6 @@ from llama_toolchain.safety.api.datatypes import *  # noqa: F403
 
 
 class CodeScannerShield(TextShield):
-
     def get_shield_type(self) -> ShieldType:
         return BuiltinShield.code_scanner_guard
 
diff --git a/llama_toolchain/safety/meta_reference/shields/llama_guard.py b/llama_toolchain/safety/meta_reference/shields/llama_guard.py
index 50044f9b6..56126abde 100644
--- a/llama_toolchain/safety/meta_reference/shields/llama_guard.py
+++ b/llama_toolchain/safety/meta_reference/shields/llama_guard.py
@@ -100,7 +100,6 @@ PROMPT_TEMPLATE = Template(
 
 
 class LlamaGuardShield(ShieldBase):
-
     @staticmethod
     def instance(
         on_violation_action=OnViolationAction.RAISE,
@@ -166,7 +165,6 @@ class LlamaGuardShield(ShieldBase):
         return None
 
     def get_safety_categories(self) -> List[str]:
-
         excluded_categories = self.excluded_categories
         if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
             excluded_categories = []
@@ -181,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
         return categories
 
     def build_prompt(self, messages: List[Message]) -> str:
-
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
@@ -225,7 +222,6 @@ class LlamaGuardShield(ShieldBase):
                 is_violation=False,
             )
         else:
-
             prompt = self.build_prompt(messages)
             llama_guard_input = {
                 "role": "user",
diff --git a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py
index 74b1757bd..0acc1e488 100644
--- a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py
+++ b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py
@@ -18,7 +18,6 @@ from llama_toolchain.safety.api.datatypes import *  # noqa: F403
 
 
 class PromptGuardShield(TextShield):
-
     class Mode(Enum):
         INJECTION = auto()
         JAILBREAK = auto()
diff --git a/llama_toolchain/synthetic_data_generation/api/endpoints.py b/llama_toolchain/synthetic_data_generation/api/endpoints.py
index 8eada05cf..00dec62b6 100644
--- a/llama_toolchain/synthetic_data_generation/api/endpoints.py
+++ b/llama_toolchain/synthetic_data_generation/api/endpoints.py
@@ -37,4 +37,5 @@ class SyntheticDataGeneration(Protocol):
     def post_generate(
         self,
         request: SyntheticDataGenerationRequest,
-    ) -> Union[SyntheticDataGenerationResponse]: ...
+    ) -> Union[SyntheticDataGenerationResponse]:
+        ...
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 4c28a4190..b6c56f769 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -36,7 +36,6 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <HF_TOKEN>
 
 
 class InferenceTests(unittest.IsolatedAsyncioTestCase):
-
     @classmethod
     def setUpClass(cls):
         # This runs the async setup function
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
index bb05eaec7..b82e5b192 100644
--- a/tests/test_ollama_inference.py
+++ b/tests/test_ollama_inference.py
@@ -20,7 +20,6 @@ from llama_toolchain.inference.ollama.ollama import get_provider_impl
 
 
 class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
-
     async def asyncSetUp(self):
         ollama_config = OllamaImplConfig(url="http://localhost:11434")