formatting

This commit is contained in:
Dalton Flanagan 2024-08-14 17:03:43 -04:00
parent 069d877210
commit b311dcd143
22 changed files with 82 additions and 128 deletions

View file

@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
] = StepType.memory_retrieval.value
)
memory_bank_ids: List[str]
documents: List[MemoryBankDocument]
scores: List[float]
@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value
] = AgenticSystemTurnResponseEventType.step_start.value
)
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value
] = AgenticSystemTurnResponseEventType.step_complete.value
)
step_type: StepType
step_details: Step
@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value
] = AgenticSystemTurnResponseEventType.step_progress.value
)
step_type: StepType
step_id: str
@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value
] = AgenticSystemTurnResponseEventType.turn_start.value
)
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value
] = AgenticSystemTurnResponseEventType.turn_complete.value
)
turn: Turn

View file

@ -67,36 +67,31 @@ class AgenticSystem(Protocol):
async def create_agentic_system(
self,
request: AgenticSystemCreateRequest,
) -> AgenticSystemCreateResponse:
...
) -> AgenticSystemCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AgenticSystemTurnResponseStreamChunk:
...
) -> AgenticSystemTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn:
...
) -> Turn: ...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse:
...
) -> AgenticSystemStepResponse: ...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse:
...
) -> AgenticSystemSessionCreateResponse: ...
@webmethod(route="/agentic_system/memory_bank/attach")
async def attach_memory_bank_to_agentic_system(
@ -104,8 +99,7 @@ class AgenticSystem(Protocol):
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None:
...
) -> None: ...
@webmethod(route="/agentic_system/memory_bank/detach")
async def detach_memory_bank_from_agentic_system(
@ -113,8 +107,7 @@ class AgenticSystem(Protocol):
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None:
...
) -> None: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
@ -122,18 +115,15 @@ class AgenticSystem(Protocol):
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
...
) -> Session: ...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None:
...
) -> None: ...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
self,
agent_id: str,
) -> None:
...
) -> None: ...

View file

@ -10,11 +10,11 @@ import os
import pkg_resources
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from termcolor import cprint
class DistributionInstall(Subcommand):
"""Llama cli for configuring llama toolchain configs"""

View file

@ -14,10 +14,10 @@ from pathlib import Path
import httpx
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from termcolor import cprint
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""

View file

@ -9,12 +9,12 @@ import json
from llama_models.sku_list import resolve_model
from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table
from llama_toolchain.common.serialize import EnumEncoder
from termcolor import colored
class ModelDescribe(Subcommand):
"""Show details about a model"""

View file

@ -7,10 +7,10 @@
import argparse
import textwrap
from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand
from termcolor import colored
class ModelTemplate(Subcommand):
"""Llama model cli for describe a model template (message formats)"""

View file

@ -1,4 +1,5 @@
import os
from llama_models.datatypes import Model
from .config_dirs import DEFAULT_CHECKPOINT_DIR

View file

@ -26,19 +26,16 @@ class Datasets(Protocol):
def create_dataset(
self,
request: CreateDatasetRequest,
) -> None:
...
) -> None: ...
@webmethod(route="/datasets/get")
def get_dataset(
self,
dataset_uuid: str,
) -> TrainEvalDataset:
...
) -> TrainEvalDataset: ...
@webmethod(route="/datasets/delete")
def delete_dataset(
self,
dataset_uuid: str,
) -> None:
...
) -> None: ...

View file

@ -63,42 +63,36 @@ class Evaluations(Protocol):
def post_evaluate_text_generation(
self,
request: EvaluateTextGenerationRequest,
) -> EvaluationJob:
...
) -> EvaluationJob: ...
@webmethod(route="/evaluate/question_answering/")
def post_evaluate_question_answering(
self,
request: EvaluateQuestionAnsweringRequest,
) -> EvaluationJob:
...
) -> EvaluationJob: ...
@webmethod(route="/evaluate/summarization/")
def post_evaluate_summarization(
self,
request: EvaluateSummarizationRequest,
) -> EvaluationJob:
...
) -> EvaluationJob: ...
@webmethod(route="/evaluate/jobs")
def get_evaluation_jobs(self) -> List[EvaluationJob]:
...
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
@webmethod(route="/evaluate/job/status")
def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
...
def get_evaluation_job_status(
self, job_uuid: str
) -> EvaluationJobStatusResponse: ...
# sends SSE stream of logs
@webmethod(route="/evaluate/job/logs")
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
...
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
@webmethod(route="/evaluate/job/cancel")
def cancel_evaluation_job(self, job_uuid: str) -> None:
...
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
@webmethod(route="/evaluate/job/artifacts")
def get_evaluation_job_artifacts(
self, job_uuid: str
) -> EvaluationJobArtifactsResponse:
...
) -> EvaluationJobArtifactsResponse: ...

View file

@ -101,26 +101,22 @@ class Inference(Protocol):
async def completion(
self,
request: CompletionRequest,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
...
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
...
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse:
...
) -> BatchCompletionResponse: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse:
...
) -> BatchChatCompletionResponse: ...

View file

@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_toolchain.inference.api import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
from termcolor import cprint
class LogEvent:

View file

@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models
from pydantic import BaseModel, Field, field_validator
from llama_toolchain.inference.api import QuantizationConfig
from pydantic import BaseModel, Field, field_validator
@json_schema_type
class MetaReferenceImplConfig(BaseModel):

View file

@ -28,10 +28,10 @@ from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3_1.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.inference.api import QuantizationType
from termcolor import cprint
from .config import MetaReferenceImplConfig

View file

@ -18,52 +18,44 @@ class MemoryBanks(Protocol):
bank_id: str,
bank_name: str,
documents: List[MemoryBankDocument],
) -> None:
...
) -> None: ...
@webmethod(route="/memory_banks/list")
def get_memory_banks(self) -> List[MemoryBank]:
...
def get_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
...
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/drop")
def delete_memory_bank(
self,
bank_id: str,
) -> str:
...
) -> str: ...
@webmethod(route="/memory_bank/insert")
def post_insert_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
...
) -> None: ...
@webmethod(route="/memory_bank/update")
def post_update_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
...
) -> None: ...
@webmethod(route="/memory_bank/get")
def get_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[MemoryBankDocument]:
...
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/delete")
def delete_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[str]:
...
) -> List[str]: ...

View file

@ -11,5 +11,4 @@ from llama_models.schema_utils import webmethod # noqa: F401
from pydantic import BaseModel # noqa: F401
class Models(Protocol):
...
class Models(Protocol): ...

View file

@ -98,35 +98,30 @@ class PostTraining(Protocol):
def post_supervised_fine_tune(
self,
request: PostTrainingSFTRequest,
) -> PostTrainingJob:
...
) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize")
def post_preference_optimize(
self,
request: PostTrainingRLHFRequest,
) -> PostTrainingJob:
...
) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]:
...
def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post_training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream:
...
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/post_training/job/status")
def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
...
def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post_training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None:
...
def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post_training/job/artifacts")
def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse:
...
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -30,5 +30,4 @@ class RewardScoring(Protocol):
def post_score(
self,
request: RewardScoringRequest,
) -> Union[RewardScoringResponse]:
...
) -> Union[RewardScoringResponse]: ...

View file

@ -11,10 +11,10 @@ from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from pydantic import BaseModel
@json_schema_type
class BuiltinShield(Enum):

View file

@ -29,5 +29,4 @@ class Safety(Protocol):
async def run_shields(
self,
request: RunShieldRequest,
) -> RunShieldResponse:
...
) -> RunShieldResponse: ...

View file

@ -37,5 +37,4 @@ class SyntheticDataGeneration(Protocol):
def post_generate(
self,
request: SyntheticDataGenerationRequest,
) -> Union[SyntheticDataGenerationResponse]:
...
) -> Union[SyntheticDataGenerationResponse]: ...

View file

@ -10,20 +10,16 @@ from datetime import datetime
from llama_models.llama3_1.api.datatypes import (
BuiltinTool,
UserMessage,
StopReason,
SystemMessage,
ToolResponseMessage,
UserMessage,
)
from llama_toolchain.inference.api.datatypes import (
ChatCompletionResponseEventType,
)
from llama_toolchain.inference.meta_reference.inference import get_provider_impl
from llama_toolchain.inference.meta_reference.config import (
MetaReferenceImplConfig,
)
from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
from llama_toolchain.inference.meta_reference.inference import get_provider_impl
MODEL = "Meta-Llama3.1-8B-Instruct"

View file

@ -4,16 +4,14 @@ from datetime import datetime
from llama_models.llama3_1.api.datatypes import (
BuiltinTool,
UserMessage,
StopReason,
SamplingParams,
SamplingStrategy,
StopReason,
SystemMessage,
ToolResponseMessage,
UserMessage,
)
from llama_toolchain.inference.api.datatypes import (
ChatCompletionResponseEventType,
)
from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.ollama.config import OllamaImplConfig
from llama_toolchain.inference.ollama.ollama import get_provider_impl