diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py index 0da30f228..1dda64834 100644 --- a/llama_toolchain/agentic_system/api/datatypes.py +++ b/llama_toolchain/agentic_system/api/datatypes.py @@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon): @json_schema_type class MemoryRetrievalStep(StepCommon): - step_type: Literal[ + step_type: Literal[StepType.memory_retrieval.value] = ( StepType.memory_retrieval.value - ] = StepType.memory_retrieval.value + ) memory_bank_ids: List[str] documents: List[MemoryBankDocument] scores: List[float] @@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum): @json_schema_type class AgenticSystemTurnResponseStepStartPayload(BaseModel): - event_type: Literal[ + event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = ( AgenticSystemTurnResponseEventType.step_start.value - ] = AgenticSystemTurnResponseEventType.step_start.value + ) step_type: StepType step_id: str metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) @@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel): @json_schema_type class AgenticSystemTurnResponseStepCompletePayload(BaseModel): - event_type: Literal[ + event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = ( AgenticSystemTurnResponseEventType.step_complete.value - ] = AgenticSystemTurnResponseEventType.step_complete.value + ) step_type: StepType step_details: Step @@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel): class AgenticSystemTurnResponseStepProgressPayload(BaseModel): model_config = ConfigDict(protected_namespaces=()) - event_type: Literal[ + event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = ( AgenticSystemTurnResponseEventType.step_progress.value - ] = AgenticSystemTurnResponseEventType.step_progress.value + ) step_type: StepType step_id: str @@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel): @json_schema_type class AgenticSystemTurnResponseTurnStartPayload(BaseModel): - event_type: Literal[ + event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = ( AgenticSystemTurnResponseEventType.turn_start.value - ] = AgenticSystemTurnResponseEventType.turn_start.value + ) turn_id: str @json_schema_type class AgenticSystemTurnResponseTurnCompletePayload(BaseModel): - event_type: Literal[ + event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = ( AgenticSystemTurnResponseEventType.turn_complete.value - ] = AgenticSystemTurnResponseEventType.turn_complete.value + ) turn: Turn diff --git a/llama_toolchain/agentic_system/api/endpoints.py b/llama_toolchain/agentic_system/api/endpoints.py index ee52f9b5a..1f6bdcc9d 100644 --- a/llama_toolchain/agentic_system/api/endpoints.py +++ b/llama_toolchain/agentic_system/api/endpoints.py @@ -67,36 +67,31 @@ class AgenticSystem(Protocol): async def create_agentic_system( self, request: AgenticSystemCreateRequest, - ) -> AgenticSystemCreateResponse: - ... + ) -> AgenticSystemCreateResponse: ... @webmethod(route="/agentic_system/turn/create") async def create_agentic_system_turn( self, request: AgenticSystemTurnCreateRequest, - ) -> AgenticSystemTurnResponseStreamChunk: - ... + ) -> AgenticSystemTurnResponseStreamChunk: ... @webmethod(route="/agentic_system/turn/get") async def get_agentic_system_turn( self, agent_id: str, turn_id: str, - ) -> Turn: - ... + ) -> Turn: ... @webmethod(route="/agentic_system/step/get") async def get_agentic_system_step( self, agent_id: str, turn_id: str, step_id: str - ) -> AgenticSystemStepResponse: - ... + ) -> AgenticSystemStepResponse: ... @webmethod(route="/agentic_system/session/create") async def create_agentic_system_session( self, request: AgenticSystemSessionCreateRequest, - ) -> AgenticSystemSessionCreateResponse: - ... + ) -> AgenticSystemSessionCreateResponse: ... @webmethod(route="/agentic_system/memory_bank/attach") async def attach_memory_bank_to_agentic_system( @@ -104,8 +99,7 @@ class AgenticSystem(Protocol): agent_id: str, session_id: str, memory_bank_ids: List[str], - ) -> None: - ... + ) -> None: ... @webmethod(route="/agentic_system/memory_bank/detach") async def detach_memory_bank_from_agentic_system( @@ -113,8 +107,7 @@ class AgenticSystem(Protocol): agent_id: str, session_id: str, memory_bank_ids: List[str], - ) -> None: - ... + ) -> None: ... @webmethod(route="/agentic_system/session/get") async def get_agentic_system_session( @@ -122,18 +115,15 @@ class AgenticSystem(Protocol): agent_id: str, session_id: str, turn_ids: Optional[List[str]] = None, - ) -> Session: - ... + ) -> Session: ... @webmethod(route="/agentic_system/session/delete") async def delete_agentic_system_session( self, agent_id: str, session_id: str - ) -> None: - ... + ) -> None: ... @webmethod(route="/agentic_system/delete") async def delete_agentic_system( self, agent_id: str, - ) -> None: - ... + ) -> None: ... diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index a056dba36..cd75effc3 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -10,11 +10,11 @@ import os import pkg_resources import yaml -from termcolor import cprint - from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR +from termcolor import cprint + class DistributionInstall(Subcommand): """Llama cli for configuring llama toolchain configs""" diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index e7f0c9f66..2a1c79220 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -14,10 +14,10 @@ from pathlib import Path import httpx -from termcolor import cprint - from llama_toolchain.cli.subcommand import Subcommand +from termcolor import cprint + class Download(Subcommand): """Llama cli for downloading llama toolchain assets""" diff --git a/llama_toolchain/cli/model/describe.py b/llama_toolchain/cli/model/describe.py index e0fb44a96..683995f7b 100644 --- a/llama_toolchain/cli/model/describe.py +++ b/llama_toolchain/cli/model/describe.py @@ -9,12 +9,12 @@ import json from llama_models.sku_list import resolve_model -from termcolor import colored - from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.table import print_table from llama_toolchain.common.serialize import EnumEncoder +from termcolor import colored + class ModelDescribe(Subcommand): """Show details about a model""" diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py index c0ba60882..58b245035 100644 --- a/llama_toolchain/cli/model/template.py +++ b/llama_toolchain/cli/model/template.py @@ -7,10 +7,10 @@ import argparse import textwrap -from termcolor import colored - from llama_toolchain.cli.subcommand import Subcommand +from termcolor import colored + class ModelTemplate(Subcommand): """Llama model cli for describe a model template (message formats)""" diff --git a/llama_toolchain/common/model_utils.py b/llama_toolchain/common/model_utils.py index af3929cb7..282e02ea8 100644 --- a/llama_toolchain/common/model_utils.py +++ b/llama_toolchain/common/model_utils.py @@ -1,4 +1,5 @@ import os + from llama_models.datatypes import Model from .config_dirs import DEFAULT_CHECKPOINT_DIR diff --git a/llama_toolchain/dataset/api/endpoints.py b/llama_toolchain/dataset/api/endpoints.py index 92d0f40c9..6a88f4b7a 100644 --- a/llama_toolchain/dataset/api/endpoints.py +++ b/llama_toolchain/dataset/api/endpoints.py @@ -26,19 +26,16 @@ class Datasets(Protocol): def create_dataset( self, request: CreateDatasetRequest, - ) -> None: - ... + ) -> None: ... @webmethod(route="/datasets/get") def get_dataset( self, dataset_uuid: str, - ) -> TrainEvalDataset: - ... + ) -> TrainEvalDataset: ... @webmethod(route="/datasets/delete") def delete_dataset( self, dataset_uuid: str, - ) -> None: - ... + ) -> None: ... diff --git a/llama_toolchain/evaluations/api/endpoints.py b/llama_toolchain/evaluations/api/endpoints.py index d8e932c88..39b9a28e0 100644 --- a/llama_toolchain/evaluations/api/endpoints.py +++ b/llama_toolchain/evaluations/api/endpoints.py @@ -63,42 +63,36 @@ class Evaluations(Protocol): def post_evaluate_text_generation( self, request: EvaluateTextGenerationRequest, - ) -> EvaluationJob: - ... + ) -> EvaluationJob: ... @webmethod(route="/evaluate/question_answering/") def post_evaluate_question_answering( self, request: EvaluateQuestionAnsweringRequest, - ) -> EvaluationJob: - ... + ) -> EvaluationJob: ... @webmethod(route="/evaluate/summarization/") def post_evaluate_summarization( self, request: EvaluateSummarizationRequest, - ) -> EvaluationJob: - ... + ) -> EvaluationJob: ... @webmethod(route="/evaluate/jobs") - def get_evaluation_jobs(self) -> List[EvaluationJob]: - ... + def get_evaluation_jobs(self) -> List[EvaluationJob]: ... @webmethod(route="/evaluate/job/status") - def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse: - ... + def get_evaluation_job_status( + self, job_uuid: str + ) -> EvaluationJobStatusResponse: ... # sends SSE stream of logs @webmethod(route="/evaluate/job/logs") - def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: - ... + def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ... @webmethod(route="/evaluate/job/cancel") - def cancel_evaluation_job(self, job_uuid: str) -> None: - ... + def cancel_evaluation_job(self, job_uuid: str) -> None: ... @webmethod(route="/evaluate/job/artifacts") def get_evaluation_job_artifacts( self, job_uuid: str - ) -> EvaluationJobArtifactsResponse: - ... + ) -> EvaluationJobArtifactsResponse: ... diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py index fe81695a0..a3ec18c95 100644 --- a/llama_toolchain/inference/api/endpoints.py +++ b/llama_toolchain/inference/api/endpoints.py @@ -101,26 +101,22 @@ class Inference(Protocol): async def completion( self, request: CompletionRequest, - ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: - ... + ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... @webmethod(route="/inference/chat_completion") async def chat_completion( self, request: ChatCompletionRequest, - ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: - ... + ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... @webmethod(route="/inference/batch_completion") async def batch_completion( self, request: BatchCompletionRequest, - ) -> BatchCompletionResponse: - ... + ) -> BatchCompletionResponse: ... @webmethod(route="/inference/batch_chat_completion") async def batch_chat_completion( self, request: BatchChatCompletionRequest, - ) -> BatchChatCompletionResponse: - ... + ) -> BatchChatCompletionResponse: ... diff --git a/llama_toolchain/inference/event_logger.py b/llama_toolchain/inference/event_logger.py index ebe241395..248ceae27 100644 --- a/llama_toolchain/inference/event_logger.py +++ b/llama_toolchain/inference/event_logger.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from termcolor import cprint - from llama_toolchain.inference.api import ( ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, ) +from termcolor import cprint class LogEvent: diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py index d2e601680..f85934118 100644 --- a/llama_toolchain/inference/meta_reference/config.py +++ b/llama_toolchain/inference/meta_reference/config.py @@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily from llama_models.schema_utils import json_schema_type from llama_models.sku_list import all_registered_models -from pydantic import BaseModel, Field, field_validator - from llama_toolchain.inference.api import QuantizationConfig +from pydantic import BaseModel, Field, field_validator + @json_schema_type class MetaReferenceImplConfig(BaseModel): diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py index 9594311ef..23df2e287 100644 --- a/llama_toolchain/inference/meta_reference/generation.py +++ b/llama_toolchain/inference/meta_reference/generation.py @@ -28,10 +28,10 @@ from llama_models.llama3_1.api.datatypes import Message from llama_models.llama3_1.api.tokenizer import Tokenizer from llama_models.llama3_1.reference_impl.model import Transformer from llama_models.sku_list import resolve_model -from termcolor import cprint from llama_toolchain.common.model_utils import model_local_dir from llama_toolchain.inference.api import QuantizationType +from termcolor import cprint from .config import MetaReferenceImplConfig diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py index 66c6719a4..4261afa89 100644 --- a/llama_toolchain/memory/api/endpoints.py +++ b/llama_toolchain/memory/api/endpoints.py @@ -18,52 +18,44 @@ class MemoryBanks(Protocol): bank_id: str, bank_name: str, documents: List[MemoryBankDocument], - ) -> None: - ... + ) -> None: ... @webmethod(route="/memory_banks/list") - def get_memory_banks(self) -> List[MemoryBank]: - ... + def get_memory_banks(self) -> List[MemoryBank]: ... @webmethod(route="/memory_banks/get") - def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: - ... + def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ... @webmethod(route="/memory_banks/drop") def delete_memory_bank( self, bank_id: str, - ) -> str: - ... + ) -> str: ... @webmethod(route="/memory_bank/insert") def post_insert_memory_documents( self, bank_id: str, documents: List[MemoryBankDocument], - ) -> None: - ... + ) -> None: ... @webmethod(route="/memory_bank/update") def post_update_memory_documents( self, bank_id: str, documents: List[MemoryBankDocument], - ) -> None: - ... + ) -> None: ... @webmethod(route="/memory_bank/get") def get_memory_documents( self, bank_id: str, document_uuids: List[str], - ) -> List[MemoryBankDocument]: - ... + ) -> List[MemoryBankDocument]: ... @webmethod(route="/memory_bank/delete") def delete_memory_documents( self, bank_id: str, document_uuids: List[str], - ) -> List[str]: - ... + ) -> List[str]: ... diff --git a/llama_toolchain/models/api/endpoints.py b/llama_toolchain/models/api/endpoints.py index 250df62e7..ee1d5f0ba 100644 --- a/llama_toolchain/models/api/endpoints.py +++ b/llama_toolchain/models/api/endpoints.py @@ -11,5 +11,4 @@ from llama_models.schema_utils import webmethod # noqa: F401 from pydantic import BaseModel # noqa: F401 -class Models(Protocol): - ... +class Models(Protocol): ... diff --git a/llama_toolchain/post_training/api/endpoints.py b/llama_toolchain/post_training/api/endpoints.py index 795e1d6f8..0512003d3 100644 --- a/llama_toolchain/post_training/api/endpoints.py +++ b/llama_toolchain/post_training/api/endpoints.py @@ -98,35 +98,30 @@ class PostTraining(Protocol): def post_supervised_fine_tune( self, request: PostTrainingSFTRequest, - ) -> PostTrainingJob: - ... + ) -> PostTrainingJob: ... @webmethod(route="/post_training/preference_optimize") def post_preference_optimize( self, request: PostTrainingRLHFRequest, - ) -> PostTrainingJob: - ... + ) -> PostTrainingJob: ... @webmethod(route="/post_training/jobs") - def get_training_jobs(self) -> List[PostTrainingJob]: - ... + def get_training_jobs(self) -> List[PostTrainingJob]: ... # sends SSE stream of logs @webmethod(route="/post_training/job/logs") - def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: - ... + def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ... @webmethod(route="/post_training/job/status") - def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: - ... + def get_training_job_status( + self, job_uuid: str + ) -> PostTrainingJobStatusResponse: ... @webmethod(route="/post_training/job/cancel") - def cancel_training_job(self, job_uuid: str) -> None: - ... + def cancel_training_job(self, job_uuid: str) -> None: ... @webmethod(route="/post_training/job/artifacts") def get_training_job_artifacts( self, job_uuid: str - ) -> PostTrainingJobArtifactsResponse: - ... + ) -> PostTrainingJobArtifactsResponse: ... diff --git a/llama_toolchain/reward_scoring/api/endpoints.py b/llama_toolchain/reward_scoring/api/endpoints.py index 7ff059f13..0a7327a9b 100644 --- a/llama_toolchain/reward_scoring/api/endpoints.py +++ b/llama_toolchain/reward_scoring/api/endpoints.py @@ -30,5 +30,4 @@ class RewardScoring(Protocol): def post_score( self, request: RewardScoringRequest, - ) -> Union[RewardScoringResponse]: - ... + ) -> Union[RewardScoringResponse]: ... diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py index a3f67615a..c5734da99 100644 --- a/llama_toolchain/safety/api/datatypes.py +++ b/llama_toolchain/safety/api/datatypes.py @@ -11,10 +11,10 @@ from llama_models.llama3_1.api.datatypes import ToolParamDefinition from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel - from llama_toolchain.common.deployment_types import RestAPIExecutionConfig +from pydantic import BaseModel + @json_schema_type class BuiltinShield(Enum): diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py index c49577674..11c1282a1 100644 --- a/llama_toolchain/safety/api/endpoints.py +++ b/llama_toolchain/safety/api/endpoints.py @@ -29,5 +29,4 @@ class Safety(Protocol): async def run_shields( self, request: RunShieldRequest, - ) -> RunShieldResponse: - ... + ) -> RunShieldResponse: ... diff --git a/llama_toolchain/synthetic_data_generation/api/endpoints.py b/llama_toolchain/synthetic_data_generation/api/endpoints.py index 00dec62b6..8eada05cf 100644 --- a/llama_toolchain/synthetic_data_generation/api/endpoints.py +++ b/llama_toolchain/synthetic_data_generation/api/endpoints.py @@ -37,5 +37,4 @@ class SyntheticDataGeneration(Protocol): def post_generate( self, request: SyntheticDataGenerationRequest, - ) -> Union[SyntheticDataGenerationResponse]: - ... + ) -> Union[SyntheticDataGenerationResponse]: ... diff --git a/tests/test_inference.py b/tests/test_inference.py index b6c56f769..14ec5cdc2 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -10,20 +10,16 @@ from datetime import datetime from llama_models.llama3_1.api.datatypes import ( BuiltinTool, - UserMessage, StopReason, SystemMessage, ToolResponseMessage, + UserMessage, ) -from llama_toolchain.inference.api.datatypes import ( - ChatCompletionResponseEventType, -) -from llama_toolchain.inference.meta_reference.inference import get_provider_impl -from llama_toolchain.inference.meta_reference.config import ( - MetaReferenceImplConfig, -) +from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType from llama_toolchain.inference.api.endpoints import ChatCompletionRequest +from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig +from llama_toolchain.inference.meta_reference.inference import get_provider_impl MODEL = "Meta-Llama3.1-8B-Instruct" diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index b82e5b192..0459cd6dc 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -4,16 +4,14 @@ from datetime import datetime from llama_models.llama3_1.api.datatypes import ( BuiltinTool, - UserMessage, - StopReason, SamplingParams, SamplingStrategy, + StopReason, SystemMessage, ToolResponseMessage, + UserMessage, ) -from llama_toolchain.inference.api.datatypes import ( - ChatCompletionResponseEventType, -) +from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType from llama_toolchain.inference.api.endpoints import ChatCompletionRequest from llama_toolchain.inference.ollama.config import OllamaImplConfig from llama_toolchain.inference.ollama.ollama import get_provider_impl