formatting

Dalton Flanagan 2024-08-14 17:03:43 -04:00
parent 069d877210
commit b311dcd143
22 changed files with 82 additions and 128 deletions

View file

@@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
 @json_schema_type
 class MemoryRetrievalStep(StepCommon):
-    step_type: Literal[
-        StepType.memory_retrieval.value
-    ] = StepType.memory_retrieval.value
+    step_type: Literal[StepType.memory_retrieval.value] = (
+        StepType.memory_retrieval.value
+    )
     memory_bank_ids: List[str]
     documents: List[MemoryBankDocument]
     scores: List[float]
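
The recurring change in this file is purely layout: a long Literal-typed field keeps its annotation on one line and wraps the default value in parentheses, instead of splitting the Literal subscript. A minimal standalone sketch of the new style, assuming a reduced StepType enum and a plain pydantic BaseModel (the real class subclasses StepCommon and has more fields):

from enum import Enum
from typing import Literal

from pydantic import BaseModel


class StepType(Enum):
    memory_retrieval = "memory_retrieval"  # reduced sketch of the real enum


class MemoryRetrievalStep(BaseModel):
    # Before: the Literal[...] subscript was split across three lines.
    # After: the annotation stays on one line and the default value is
    # parenthesized so the assignment fits the line-length limit.
    step_type: Literal[StepType.memory_retrieval.value] = (
        StepType.memory_retrieval.value
    )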
@@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
 @json_schema_type
 class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.step_start.value
-    ] = AgenticSystemTurnResponseEventType.step_start.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
+        AgenticSystemTurnResponseEventType.step_start.value
+    )
     step_type: StepType
     step_id: str
     metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
 @json_schema_type
 class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.step_complete.value
-    ] = AgenticSystemTurnResponseEventType.step_complete.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
+        AgenticSystemTurnResponseEventType.step_complete.value
+    )
     step_type: StepType
     step_details: Step
@@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
 class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())

-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.step_progress.value
-    ] = AgenticSystemTurnResponseEventType.step_progress.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
+        AgenticSystemTurnResponseEventType.step_progress.value
+    )
     step_type: StepType
     step_id: str
@@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
 @json_schema_type
 class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.turn_start.value
-    ] = AgenticSystemTurnResponseEventType.turn_start.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
+        AgenticSystemTurnResponseEventType.turn_start.value
+    )
     turn_id: str


 @json_schema_type
 class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.turn_complete.value
-    ] = AgenticSystemTurnResponseEventType.turn_complete.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
+        AgenticSystemTurnResponseEventType.turn_complete.value
+    )
     turn: Turn

View file

@@ -67,36 +67,31 @@ class AgenticSystem(Protocol):
     async def create_agentic_system(
         self,
         request: AgenticSystemCreateRequest,
-    ) -> AgenticSystemCreateResponse:
-        ...
+    ) -> AgenticSystemCreateResponse: ...

     @webmethod(route="/agentic_system/turn/create")
     async def create_agentic_system_turn(
         self,
         request: AgenticSystemTurnCreateRequest,
-    ) -> AgenticSystemTurnResponseStreamChunk:
-        ...
+    ) -> AgenticSystemTurnResponseStreamChunk: ...

     @webmethod(route="/agentic_system/turn/get")
     async def get_agentic_system_turn(
         self,
         agent_id: str,
         turn_id: str,
-    ) -> Turn:
-        ...
+    ) -> Turn: ...

     @webmethod(route="/agentic_system/step/get")
     async def get_agentic_system_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse:
-        ...
+    ) -> AgenticSystemStepResponse: ...

     @webmethod(route="/agentic_system/session/create")
     async def create_agentic_system_session(
         self,
         request: AgenticSystemSessionCreateRequest,
-    ) -> AgenticSystemSessionCreateResponse:
-        ...
+    ) -> AgenticSystemSessionCreateResponse: ...

     @webmethod(route="/agentic_system/memory_bank/attach")
     async def attach_memory_bank_to_agentic_system(
@@ -104,8 +99,7 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None:
-        ...
+    ) -> None: ...

     @webmethod(route="/agentic_system/memory_bank/detach")
     async def detach_memory_bank_from_agentic_system(
@@ -113,8 +107,7 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None:
-        ...
+    ) -> None: ...

     @webmethod(route="/agentic_system/session/get")
     async def get_agentic_system_session(
@@ -122,18 +115,15 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
-    ) -> Session:
-        ...
+    ) -> Session: ...

     @webmethod(route="/agentic_system/session/delete")
     async def delete_agentic_system_session(
         self, agent_id: str, session_id: str
-    ) -> None:
-        ...
+    ) -> None: ...

     @webmethod(route="/agentic_system/delete")
     async def delete_agentic_system(
         self,
         agent_id: str,
-    ) -> None:
-        ...
+    ) -> None: ...
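
The endpoint files in this commit all make the same switch: since typing.Protocol methods have no body, the ellipsis placeholder now sits on the same line as the signature instead of on its own indented line. A minimal, self-contained sketch of the two layouts, using hypothetical names rather than the repo's classes:

from typing import List, Protocol


class JobStore(Protocol):
    # Style before this commit:
    #     def cancel_job(self, job_uuid: str) -> None:
    #         ...
    # Style after this commit:
    def cancel_job(self, job_uuid: str) -> None: ...

    def list_jobs(self) -> List[str]: ...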

View file

@@ -10,11 +10,11 @@ import os
 import pkg_resources
 import yaml

-from termcolor import cprint
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
+from termcolor import cprint


 class DistributionInstall(Subcommand):
     """Llama cli for configuring llama toolchain configs"""

View file

@@ -14,10 +14,10 @@ from pathlib import Path
 import httpx

-from termcolor import cprint
 from llama_toolchain.cli.subcommand import Subcommand
+from termcolor import cprint


 class Download(Subcommand):
     """Llama cli for downloading llama toolchain assets"""

View file

@@ -9,12 +9,12 @@ import json
 from llama_models.sku_list import resolve_model

-from termcolor import colored
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table
 from llama_toolchain.common.serialize import EnumEncoder
+from termcolor import colored


 class ModelDescribe(Subcommand):
     """Show details about a model"""

View file

@@ -7,10 +7,10 @@
 import argparse
 import textwrap

-from termcolor import colored
 from llama_toolchain.cli.subcommand import Subcommand
+from termcolor import colored


 class ModelTemplate(Subcommand):
     """Llama model cli for describe a model template (message formats)"""

View file

@@ -1,4 +1,5 @@
 import os
 from llama_models.datatypes import Model
 from .config_dirs import DEFAULT_CHECKPOINT_DIR

View file

@@ -26,19 +26,16 @@ class Datasets(Protocol):
     def create_dataset(
         self,
         request: CreateDatasetRequest,
-    ) -> None:
-        ...
+    ) -> None: ...

     @webmethod(route="/datasets/get")
     def get_dataset(
         self,
         dataset_uuid: str,
-    ) -> TrainEvalDataset:
-        ...
+    ) -> TrainEvalDataset: ...

     @webmethod(route="/datasets/delete")
     def delete_dataset(
         self,
         dataset_uuid: str,
-    ) -> None:
-        ...
+    ) -> None: ...

View file

@@ -63,42 +63,36 @@ class Evaluations(Protocol):
     def post_evaluate_text_generation(
         self,
         request: EvaluateTextGenerationRequest,
-    ) -> EvaluationJob:
-        ...
+    ) -> EvaluationJob: ...

     @webmethod(route="/evaluate/question_answering/")
     def post_evaluate_question_answering(
         self,
         request: EvaluateQuestionAnsweringRequest,
-    ) -> EvaluationJob:
-        ...
+    ) -> EvaluationJob: ...

     @webmethod(route="/evaluate/summarization/")
     def post_evaluate_summarization(
         self,
         request: EvaluateSummarizationRequest,
-    ) -> EvaluationJob:
-        ...
+    ) -> EvaluationJob: ...

     @webmethod(route="/evaluate/jobs")
-    def get_evaluation_jobs(self) -> List[EvaluationJob]:
-        ...
+    def get_evaluation_jobs(self) -> List[EvaluationJob]: ...

     @webmethod(route="/evaluate/job/status")
-    def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
-        ...
+    def get_evaluation_job_status(
+        self, job_uuid: str
+    ) -> EvaluationJobStatusResponse: ...

     # sends SSE stream of logs
     @webmethod(route="/evaluate/job/logs")
-    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
-        ...
+    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...

     @webmethod(route="/evaluate/job/cancel")
-    def cancel_evaluation_job(self, job_uuid: str) -> None:
-        ...
+    def cancel_evaluation_job(self, job_uuid: str) -> None: ...

     @webmethod(route="/evaluate/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
-    ) -> EvaluationJobArtifactsResponse:
-        ...
+    ) -> EvaluationJobArtifactsResponse: ...

View file

@@ -101,26 +101,22 @@ class Inference(Protocol):
     async def completion(
         self,
         request: CompletionRequest,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        ...
+    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
         request: ChatCompletionRequest,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        ...
+    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...

     @webmethod(route="/inference/batch_completion")
     async def batch_completion(
         self,
         request: BatchCompletionRequest,
-    ) -> BatchCompletionResponse:
-        ...
+    ) -> BatchCompletionResponse: ...

     @webmethod(route="/inference/batch_chat_completion")
     async def batch_chat_completion(
         self,
         request: BatchChatCompletionRequest,
-    ) -> BatchChatCompletionResponse:
-        ...
+    ) -> BatchChatCompletionResponse: ...

View file

@@ -4,12 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from termcolor import cprint
 from llama_toolchain.inference.api import (
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
 )
+from termcolor import cprint


 class LogEvent:

View file

@@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
 from llama_models.sku_list import all_registered_models

-from pydantic import BaseModel, Field, field_validator
 from llama_toolchain.inference.api import QuantizationConfig
+from pydantic import BaseModel, Field, field_validator


 @json_schema_type
 class MetaReferenceImplConfig(BaseModel):

View file

@@ -28,10 +28,10 @@ from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from llama_models.llama3_1.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model

-from termcolor import cprint
 from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType
+from termcolor import cprint

 from .config import MetaReferenceImplConfig

View file

@@ -18,52 +18,44 @@ class MemoryBanks(Protocol):
         bank_id: str,
         bank_name: str,
         documents: List[MemoryBankDocument],
-    ) -> None:
-        ...
+    ) -> None: ...

     @webmethod(route="/memory_banks/list")
-    def get_memory_banks(self) -> List[MemoryBank]:
-        ...
+    def get_memory_banks(self) -> List[MemoryBank]: ...

     @webmethod(route="/memory_banks/get")
-    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
-        ...
+    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...

     @webmethod(route="/memory_banks/drop")
     def delete_memory_bank(
         self,
         bank_id: str,
-    ) -> str:
-        ...
+    ) -> str: ...

     @webmethod(route="/memory_bank/insert")
     def post_insert_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None:
-        ...
+    ) -> None: ...

     @webmethod(route="/memory_bank/update")
     def post_update_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None:
-        ...
+    ) -> None: ...

     @webmethod(route="/memory_bank/get")
     def get_memory_documents(
         self,
         bank_id: str,
         document_uuids: List[str],
-    ) -> List[MemoryBankDocument]:
-        ...
+    ) -> List[MemoryBankDocument]: ...

     @webmethod(route="/memory_bank/delete")
     def delete_memory_documents(
         self,
         bank_id: str,
         document_uuids: List[str],
-    ) -> List[str]:
-        ...
+    ) -> List[str]: ...

View file

@@ -11,5 +11,4 @@ from llama_models.schema_utils import webmethod  # noqa: F401
 from pydantic import BaseModel  # noqa: F401


-class Models(Protocol):
-    ...
+class Models(Protocol): ...

View file

@@ -98,35 +98,30 @@ class PostTraining(Protocol):
     def post_supervised_fine_tune(
         self,
         request: PostTrainingSFTRequest,
-    ) -> PostTrainingJob:
-        ...
+    ) -> PostTrainingJob: ...

     @webmethod(route="/post_training/preference_optimize")
     def post_preference_optimize(
         self,
         request: PostTrainingRLHFRequest,
-    ) -> PostTrainingJob:
-        ...
+    ) -> PostTrainingJob: ...

     @webmethod(route="/post_training/jobs")
-    def get_training_jobs(self) -> List[PostTrainingJob]:
-        ...
+    def get_training_jobs(self) -> List[PostTrainingJob]: ...

     # sends SSE stream of logs
     @webmethod(route="/post_training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream:
-        ...
+    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...

     @webmethod(route="/post_training/job/status")
-    def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
-        ...
+    def get_training_job_status(
+        self, job_uuid: str
+    ) -> PostTrainingJobStatusResponse: ...

     @webmethod(route="/post_training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None:
-        ...
+    def cancel_training_job(self, job_uuid: str) -> None: ...

     @webmethod(route="/post_training/job/artifacts")
     def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse:
-        ...
+    ) -> PostTrainingJobArtifactsResponse: ...

View file

@@ -30,5 +30,4 @@ class RewardScoring(Protocol):
     def post_score(
         self,
         request: RewardScoringRequest,
-    ) -> Union[RewardScoringResponse]:
-        ...
+    ) -> Union[RewardScoringResponse]: ...

View file

@@ -11,10 +11,10 @@ from llama_models.llama3_1.api.datatypes import ToolParamDefinition
 from llama_models.schema_utils import json_schema_type

-from pydantic import BaseModel
 from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
+from pydantic import BaseModel


 @json_schema_type
 class BuiltinShield(Enum):

View file

@@ -29,5 +29,4 @@ class Safety(Protocol):
     async def run_shields(
         self,
         request: RunShieldRequest,
-    ) -> RunShieldResponse:
-        ...
+    ) -> RunShieldResponse: ...

View file

@@ -37,5 +37,4 @@ class SyntheticDataGeneration(Protocol):
     def post_generate(
         self,
         request: SyntheticDataGenerationRequest,
-    ) -> Union[SyntheticDataGenerationResponse]:
-        ...
+    ) -> Union[SyntheticDataGenerationResponse]: ...

View file

@@ -10,20 +10,16 @@ from datetime import datetime
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    UserMessage,
     StopReason,
     SystemMessage,
     ToolResponseMessage,
+    UserMessage,
 )

-from llama_toolchain.inference.api.datatypes import (
-    ChatCompletionResponseEventType,
-)
-from llama_toolchain.inference.meta_reference.inference import get_provider_impl
-from llama_toolchain.inference.meta_reference.config import (
-    MetaReferenceImplConfig,
-)
+from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
 from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
+from llama_toolchain.inference.meta_reference.inference import get_provider_impl

 MODEL = "Meta-Llama3.1-8B-Instruct"

View file

@@ -4,16 +4,14 @@ from datetime import datetime
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    UserMessage,
-    StopReason,
     SamplingParams,
     SamplingStrategy,
+    StopReason,
     SystemMessage,
     ToolResponseMessage,
+    UserMessage,
 )

-from llama_toolchain.inference.api.datatypes import (
-    ChatCompletionResponseEventType,
-)
+from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
 from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
 from llama_toolchain.inference.ollama.config import OllamaImplConfig
 from llama_toolchain.inference.ollama.ollama import get_provider_impl