Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)

Commit: b311dcd143
Parent: 069d877210
Message: formatting

22 changed files with 82 additions and 128 deletions
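The changes below are mechanical and fall into three recurring patterns: `Literal[...]` field defaults re-wrapped so the type subscript stays on one line and the default value moves into trailing parentheses; protocol method stubs collapsed so the `...` body sits on the same line as the closing `) -> ReturnType:`; and import blocks re-sorted alphabetically, which moves third-party imports such as `pydantic` and `termcolor` below the `llama_toolchain` imports they previously preceded. The commit message does not name a tool, but the output is consistent with an automated formatter pass. File paths were lost when this page was mirrored; each hunk is identified only by the context in its `@@` header.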
@@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
 
 @json_schema_type
 class MemoryRetrievalStep(StepCommon):
-    step_type: Literal[
-        StepType.memory_retrieval.value
-    ] = StepType.memory_retrieval.value
+    step_type: Literal[StepType.memory_retrieval.value] = (
+        StepType.memory_retrieval.value
+    )
     memory_bank_ids: List[str]
     documents: List[MemoryBankDocument]
     scores: List[float]
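The two spellings above differ only in line-wrapping. A minimal runnable sketch (with a stand-in `StepType` enum rather than the project's full datatypes) showing that both forms parse and default identically:

from enum import Enum
from typing import Literal

from pydantic import BaseModel


class StepType(Enum):
    # Stand-in for the StepType enum referenced in the diff; only one
    # member is needed for this illustration.
    memory_retrieval = "memory_retrieval"


class OldStyle(BaseModel):
    # Pre-commit wrapping: subscript and default split over three lines.
    step_type: Literal[
        StepType.memory_retrieval.value
    ] = StepType.memory_retrieval.value


class NewStyle(BaseModel):
    # Post-commit wrapping: one-line annotation, parenthesized default.
    step_type: Literal[StepType.memory_retrieval.value] = (
        StepType.memory_retrieval.value
    )


# Both models behave the same; only the source layout changed.
assert OldStyle().step_type == NewStyle().step_type == "memory_retrieval"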
@@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
 
 @json_schema_type
 class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.step_start.value
-    ] = AgenticSystemTurnResponseEventType.step_start.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
+        AgenticSystemTurnResponseEventType.step_start.value
+    )
     step_type: StepType
     step_id: str
     metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
 
 @json_schema_type
 class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.step_complete.value
-    ] = AgenticSystemTurnResponseEventType.step_complete.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
+        AgenticSystemTurnResponseEventType.step_complete.value
+    )
     step_type: StepType
     step_details: Step
 
@@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
 class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.step_progress.value
-    ] = AgenticSystemTurnResponseEventType.step_progress.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
+        AgenticSystemTurnResponseEventType.step_progress.value
+    )
     step_type: StepType
     step_id: str
 
@@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
 
 @json_schema_type
 class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.turn_start.value
-    ] = AgenticSystemTurnResponseEventType.turn_start.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
+        AgenticSystemTurnResponseEventType.turn_start.value
+    )
     turn_id: str
 
 
 @json_schema_type
 class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[
-        AgenticSystemTurnResponseEventType.turn_complete.value
-    ] = AgenticSystemTurnResponseEventType.turn_complete.value
+    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
+        AgenticSystemTurnResponseEventType.turn_complete.value
+    )
     turn: Turn
 
 
@@ -67,36 +67,31 @@ class AgenticSystem(Protocol):
     async def create_agentic_system(
         self,
         request: AgenticSystemCreateRequest,
-    ) -> AgenticSystemCreateResponse:
-        ...
+    ) -> AgenticSystemCreateResponse: ...
 
     @webmethod(route="/agentic_system/turn/create")
     async def create_agentic_system_turn(
         self,
         request: AgenticSystemTurnCreateRequest,
-    ) -> AgenticSystemTurnResponseStreamChunk:
-        ...
+    ) -> AgenticSystemTurnResponseStreamChunk: ...
 
     @webmethod(route="/agentic_system/turn/get")
     async def get_agentic_system_turn(
         self,
         agent_id: str,
         turn_id: str,
-    ) -> Turn:
-        ...
+    ) -> Turn: ...
 
     @webmethod(route="/agentic_system/step/get")
     async def get_agentic_system_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse:
-        ...
+    ) -> AgenticSystemStepResponse: ...
 
     @webmethod(route="/agentic_system/session/create")
     async def create_agentic_system_session(
         self,
         request: AgenticSystemSessionCreateRequest,
-    ) -> AgenticSystemSessionCreateResponse:
-        ...
+    ) -> AgenticSystemSessionCreateResponse: ...
 
     @webmethod(route="/agentic_system/memory_bank/attach")
     async def attach_memory_bank_to_agentic_system(
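The collapsed stubs in this hunk and the ones that follow are semantically identical to the two-line form: in a `typing.Protocol`, a bare `...` is the conventional no-op body, and Python accepts it on the same line as the signature. A self-contained sketch with hypothetical names (not the project's API):

from typing import Protocol


class Greeter(Protocol):
    # Pre-commit style: the ellipsis body on its own line.
    def greet(
        self,
        name: str,
    ) -> str:
        ...

    # Post-commit style: the ellipsis folded onto the signature line.
    # Both compile to the same thing.
    def farewell(self, name: str) -> str: ...


class Console:
    # Structural implementation; no explicit inheritance from Greeter needed.
    def greet(self, name: str) -> str:
        return f"hello, {name}"

    def farewell(self, name: str) -> str:
        return f"goodbye, {name}"


greeter: Greeter = Console()  # accepted structurally by static checkers
print(greeter.greet("llama"))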
@@ -104,8 +99,7 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @webmethod(route="/agentic_system/memory_bank/detach")
     async def detach_memory_bank_from_agentic_system(
@@ -113,8 +107,7 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         memory_bank_ids: List[str],
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @webmethod(route="/agentic_system/session/get")
     async def get_agentic_system_session(
@@ -122,18 +115,15 @@ class AgenticSystem(Protocol):
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
-    ) -> Session:
-        ...
+    ) -> Session: ...
 
     @webmethod(route="/agentic_system/session/delete")
     async def delete_agentic_system_session(
         self, agent_id: str, session_id: str
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @webmethod(route="/agentic_system/delete")
     async def delete_agentic_system(
         self,
         agent_id: str,
-    ) -> None:
-        ...
+    ) -> None: ...
@@ -10,11 +10,11 @@ import os
 import pkg_resources
 import yaml
 
-from termcolor import cprint
-
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
 
+from termcolor import cprint
+
 
 class DistributionInstall(Subcommand):
     """Llama cli for configuring llama toolchain configs"""
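This hunk and the similar ones below all enforce the same ordering: one alphabetized run of absolute `from` imports, so `llama_toolchain.*` now sorts before `pydantic` and `termcolor`, with relative imports last. A small hypothetical checker (not part of llama_toolchain) that verifies the invariant with the standard library's `ast` module:

import ast
import sys


def from_import_modules(source: str) -> list:
    """Collect module paths of top-level `from X import ...` statements,
    skipping relative imports (level > 0), which sort separately."""
    tree = ast.parse(source)
    return [
        node.module or ""
        for node in tree.body
        if isinstance(node, ast.ImportFrom) and node.level == 0
    ]


if __name__ == "__main__":
    modules = from_import_modules(open(sys.argv[1]).read())
    if modules == sorted(modules):
        print("from-imports are alphabetized")
    else:
        print("out of order:", modules)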
@@ -14,10 +14,10 @@ from pathlib import Path
 
 import httpx
 
-from termcolor import cprint
-
 from llama_toolchain.cli.subcommand import Subcommand
 
+from termcolor import cprint
+
 
 class Download(Subcommand):
     """Llama cli for downloading llama toolchain assets"""
@@ -9,12 +9,12 @@ import json
 
 from llama_models.sku_list import resolve_model
 
-from termcolor import colored
-
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table
 from llama_toolchain.common.serialize import EnumEncoder
 
+from termcolor import colored
+
 
 class ModelDescribe(Subcommand):
     """Show details about a model"""
@@ -7,10 +7,10 @@
 import argparse
 import textwrap
 
-from termcolor import colored
-
 from llama_toolchain.cli.subcommand import Subcommand
 
+from termcolor import colored
+
 
 class ModelTemplate(Subcommand):
     """Llama model cli for describe a model template (message formats)"""
@@ -1,4 +1,5 @@
 import os
+
 from llama_models.datatypes import Model
 
 from .config_dirs import DEFAULT_CHECKPOINT_DIR
@@ -26,19 +26,16 @@ class Datasets(Protocol):
     def create_dataset(
         self,
         request: CreateDatasetRequest,
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @webmethod(route="/datasets/get")
     def get_dataset(
         self,
         dataset_uuid: str,
-    ) -> TrainEvalDataset:
-        ...
+    ) -> TrainEvalDataset: ...
 
     @webmethod(route="/datasets/delete")
     def delete_dataset(
         self,
         dataset_uuid: str,
-    ) -> None:
-        ...
+    ) -> None: ...
@@ -63,42 +63,36 @@ class Evaluations(Protocol):
     def post_evaluate_text_generation(
         self,
         request: EvaluateTextGenerationRequest,
-    ) -> EvaluationJob:
-        ...
+    ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/question_answering/")
     def post_evaluate_question_answering(
         self,
         request: EvaluateQuestionAnsweringRequest,
-    ) -> EvaluationJob:
-        ...
+    ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/summarization/")
     def post_evaluate_summarization(
         self,
         request: EvaluateSummarizationRequest,
-    ) -> EvaluationJob:
-        ...
+    ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/jobs")
-    def get_evaluation_jobs(self) -> List[EvaluationJob]:
-        ...
+    def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
 
     @webmethod(route="/evaluate/job/status")
-    def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
-        ...
+    def get_evaluation_job_status(
+        self, job_uuid: str
+    ) -> EvaluationJobStatusResponse: ...
 
     # sends SSE stream of logs
     @webmethod(route="/evaluate/job/logs")
-    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
-        ...
+    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
 
     @webmethod(route="/evaluate/job/cancel")
-    def cancel_evaluation_job(self, job_uuid: str) -> None:
-        ...
+    def cancel_evaluation_job(self, job_uuid: str) -> None: ...
 
     @webmethod(route="/evaluate/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
-    ) -> EvaluationJobArtifactsResponse:
-        ...
+    ) -> EvaluationJobArtifactsResponse: ...
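A detail worth noticing in this hunk: `get_evaluation_job_status` was re-split across three lines while `get_evaluation_job_logstream` stayed on one. Folding the `...` onto the status stub would produce a 90-character line, but the logstream stub lands at exactly 88 characters, so the split point is consistent with an 88-character line limit (black's default), though the commit does not name the tool.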
@@ -101,26 +101,22 @@ class Inference(Protocol):
     async def completion(
         self,
         request: CompletionRequest,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        ...
+    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
 
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
         request: ChatCompletionRequest,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        ...
+    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
 
     @webmethod(route="/inference/batch_completion")
     async def batch_completion(
         self,
         request: BatchCompletionRequest,
-    ) -> BatchCompletionResponse:
-        ...
+    ) -> BatchCompletionResponse: ...
 
     @webmethod(route="/inference/batch_chat_completion")
     async def batch_chat_completion(
         self,
         request: BatchChatCompletionRequest,
-    ) -> BatchChatCompletionResponse:
-        ...
+    ) -> BatchChatCompletionResponse: ...
@@ -4,12 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from termcolor import cprint
-
 from llama_toolchain.inference.api import (
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
 )
+from termcolor import cprint
 
 
 class LogEvent:
@@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
 from llama_models.sku_list import all_registered_models
 
-from pydantic import BaseModel, Field, field_validator
-
 from llama_toolchain.inference.api import QuantizationConfig
 
+from pydantic import BaseModel, Field, field_validator
+
 
 @json_schema_type
 class MetaReferenceImplConfig(BaseModel):
@@ -28,10 +28,10 @@ from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from llama_models.llama3_1.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
 
 from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType
+from termcolor import cprint
 
 from .config import MetaReferenceImplConfig
 
@@ -18,52 +18,44 @@ class MemoryBanks(Protocol):
         bank_id: str,
         bank_name: str,
         documents: List[MemoryBankDocument],
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @webmethod(route="/memory_banks/list")
-    def get_memory_banks(self) -> List[MemoryBank]:
-        ...
+    def get_memory_banks(self) -> List[MemoryBank]: ...
 
     @webmethod(route="/memory_banks/get")
-    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
-        ...
+    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
 
     @webmethod(route="/memory_banks/drop")
     def delete_memory_bank(
         self,
         bank_id: str,
-    ) -> str:
-        ...
+    ) -> str: ...
 
     @webmethod(route="/memory_bank/insert")
     def post_insert_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @webmethod(route="/memory_bank/update")
     def post_update_memory_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @webmethod(route="/memory_bank/get")
     def get_memory_documents(
         self,
         bank_id: str,
         document_uuids: List[str],
-    ) -> List[MemoryBankDocument]:
-        ...
+    ) -> List[MemoryBankDocument]: ...
 
     @webmethod(route="/memory_bank/delete")
     def delete_memory_documents(
         self,
         bank_id: str,
         document_uuids: List[str],
-    ) -> List[str]:
-        ...
+    ) -> List[str]: ...
@@ -11,5 +11,4 @@ from llama_models.schema_utils import webmethod  # noqa: F401
 from pydantic import BaseModel  # noqa: F401
 
 
-class Models(Protocol):
-    ...
+class Models(Protocol): ...
@@ -98,35 +98,30 @@ class PostTraining(Protocol):
     def post_supervised_fine_tune(
         self,
         request: PostTrainingSFTRequest,
-    ) -> PostTrainingJob:
-        ...
+    ) -> PostTrainingJob: ...
 
     @webmethod(route="/post_training/preference_optimize")
     def post_preference_optimize(
         self,
         request: PostTrainingRLHFRequest,
-    ) -> PostTrainingJob:
-        ...
+    ) -> PostTrainingJob: ...
 
     @webmethod(route="/post_training/jobs")
-    def get_training_jobs(self) -> List[PostTrainingJob]:
-        ...
+    def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
     # sends SSE stream of logs
     @webmethod(route="/post_training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream:
-        ...
+    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
 
     @webmethod(route="/post_training/job/status")
-    def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
-        ...
+    def get_training_job_status(
+        self, job_uuid: str
+    ) -> PostTrainingJobStatusResponse: ...
 
     @webmethod(route="/post_training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None:
-        ...
+    def cancel_training_job(self, job_uuid: str) -> None: ...
 
     @webmethod(route="/post_training/job/artifacts")
     def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse:
-        ...
+    ) -> PostTrainingJobArtifactsResponse: ...
@@ -30,5 +30,4 @@ class RewardScoring(Protocol):
     def post_score(
         self,
         request: RewardScoringRequest,
-    ) -> Union[RewardScoringResponse]:
-        ...
+    ) -> Union[RewardScoringResponse]: ...
@@ -11,10 +11,10 @@ from llama_models.llama3_1.api.datatypes import ToolParamDefinition
 
 from llama_models.schema_utils import json_schema_type
 
-from pydantic import BaseModel
-
 from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
 
+from pydantic import BaseModel
+
 
 @json_schema_type
 class BuiltinShield(Enum):
@@ -29,5 +29,4 @@ class Safety(Protocol):
     async def run_shields(
         self,
         request: RunShieldRequest,
-    ) -> RunShieldResponse:
-        ...
+    ) -> RunShieldResponse: ...
@@ -37,5 +37,4 @@ class SyntheticDataGeneration(Protocol):
     def post_generate(
         self,
         request: SyntheticDataGenerationRequest,
-    ) -> Union[SyntheticDataGenerationResponse]:
-        ...
+    ) -> Union[SyntheticDataGenerationResponse]: ...
@@ -10,20 +10,16 @@ from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    UserMessage,
     StopReason,
     SystemMessage,
     ToolResponseMessage,
+    UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import (
-    ChatCompletionResponseEventType,
-)
-from llama_toolchain.inference.meta_reference.inference import get_provider_impl
-from llama_toolchain.inference.meta_reference.config import (
-    MetaReferenceImplConfig,
-)
+from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
 
 from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
+from llama_toolchain.inference.meta_reference.inference import get_provider_impl
 
 
 MODEL = "Meta-Llama3.1-8B-Instruct"
@@ -4,16 +4,14 @@ from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    UserMessage,
-    StopReason,
     SamplingParams,
     SamplingStrategy,
+    StopReason,
     SystemMessage,
     ToolResponseMessage,
+    UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import (
-    ChatCompletionResponseEventType,
-)
+from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
 from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
 from llama_toolchain.inference.ollama.config import OllamaImplConfig
 from llama_toolchain.inference.ollama.ollama import get_provider_impl
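The final two hunks show two more habits of the same pass: names inside a parenthesized `from ... import (...)` list are alphabetized (`UserMessage` moves below `ToolResponseMessage`, `StopReason` below `SamplingStrategy`), and a parenthesized import of a single name collapses onto one line when it fits the limit, as with `from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType` at 83 characters.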