mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
kill older junk
This commit is contained in:
parent
95781ec85d
commit
f94efcf2ee
11 changed files with 0 additions and 6898 deletions
|
@ -1,98 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
|
||||||
|
|
||||||
from model_types import (
|
|
||||||
BuiltinTool,
|
|
||||||
Content,
|
|
||||||
InstructModel,
|
|
||||||
Message,
|
|
||||||
PretrainedModel,
|
|
||||||
SamplingParams,
|
|
||||||
SafetyViolation,
|
|
||||||
StopReason,
|
|
||||||
ToolCall,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
from strong_typing.schema import json_schema_type
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionStepType(Enum):
|
|
||||||
"""The type of execution step."""
|
|
||||||
|
|
||||||
model_inference = "model_inference"
|
|
||||||
tool_execution = "tool_execution"
|
|
||||||
safety_filtering = "safety_filtering"
|
|
||||||
memory_retrieval = "memory_retrieval"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExecutionStepBase:
|
|
||||||
"""An agentic system turn can consist of one or more such execution steps."""
|
|
||||||
|
|
||||||
step_type: ExecutionStepType
|
|
||||||
uuid: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelInferenceStep(ExecutionStepBase):
|
|
||||||
step_type = ExecutionStepType.model_inference
|
|
||||||
text: str
|
|
||||||
logprobs: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolExecutionStep(ExecutionStepBase):
|
|
||||||
step_type = ExecutionStepType.tool_execution
|
|
||||||
|
|
||||||
# we could be calling multiple tools in a single step (in parallel)
|
|
||||||
tool_calls: List[ToolCall]
|
|
||||||
tool_responses: List[ToolResponse]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SafetyFilteringStep(ExecutionStepBase):
|
|
||||||
step_type = ExecutionStepType.safety_filtering
|
|
||||||
violation: Optional[SafetyViolation] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class MemoryBank:
|
|
||||||
uuid: str
|
|
||||||
name: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MemoryBankDocument:
|
|
||||||
uuid: str
|
|
||||||
content: bytes
|
|
||||||
metadata: Dict[str, Any]
|
|
||||||
mime_type: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MemoryRetrievalStep(ExecutionStepBase):
|
|
||||||
step_type = ExecutionStepType.memory_retrieval
|
|
||||||
documents: List[MemoryBankDocument]
|
|
||||||
scores: List[float]
|
|
||||||
|
|
||||||
|
|
||||||
ExecutionStep = Union[
|
|
||||||
ModelInferenceStep,
|
|
||||||
ToolExecutionStep,
|
|
||||||
SafetyFilteringStep,
|
|
||||||
MemoryRetrievalStep,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class AgenticSystemTurn:
|
|
||||||
"""A single turn in an interaction with an Agentic System."""
|
|
||||||
|
|
||||||
user_messages: List[Message]
|
|
||||||
steps: List[ExecutionStep]
|
|
||||||
response_message: Message
|
|
|
@ -1,563 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Union
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from agentic_system_types import (
|
|
||||||
AgenticSystemTurn,
|
|
||||||
ExecutionStepType,
|
|
||||||
MemoryBank,
|
|
||||||
MemoryBankDocument,
|
|
||||||
SafetyViolation,
|
|
||||||
)
|
|
||||||
|
|
||||||
from model_types import (
|
|
||||||
BuiltinTool,
|
|
||||||
Content,
|
|
||||||
Dialog,
|
|
||||||
InstructModel,
|
|
||||||
Message,
|
|
||||||
PretrainedModel,
|
|
||||||
RewardModel,
|
|
||||||
SamplingParams,
|
|
||||||
ShieldConfig,
|
|
||||||
StopReason,
|
|
||||||
ToolCall,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolResponse,
|
|
||||||
URL,
|
|
||||||
)
|
|
||||||
|
|
||||||
from post_training_types import (
|
|
||||||
Checkpoint,
|
|
||||||
Dataset,
|
|
||||||
DoraFinetuningConfig,
|
|
||||||
DPOAlignmentConfig,
|
|
||||||
FinetuningAlgorithm,
|
|
||||||
LoraFinetuningConfig,
|
|
||||||
OptimizerConfig,
|
|
||||||
PostTrainingJobLogStream,
|
|
||||||
PostTrainingJobStatus,
|
|
||||||
QLoraFinetuningConfig,
|
|
||||||
RLHFAlgorithm,
|
|
||||||
TrainingConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pyopenapi import Info, Options, Server, Specification, webmethod
|
|
||||||
from strong_typing.schema import json_schema_type
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class CompletionRequest:
|
|
||||||
content: Content
|
|
||||||
model: PretrainedModel
|
|
||||||
sampling_params: SamplingParams = SamplingParams()
|
|
||||||
max_tokens: int = 0
|
|
||||||
stream: bool = False
|
|
||||||
logprobs: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class CompletionResponse:
|
|
||||||
"""Normal completion response."""
|
|
||||||
|
|
||||||
content: Content
|
|
||||||
stop_reason: Optional[StopReason] = None
|
|
||||||
logprobs: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class CompletionResponseStreamChunk:
|
|
||||||
"""streamed completion response."""
|
|
||||||
|
|
||||||
text_delta: str
|
|
||||||
stop_reason: Optional[StopReason] = None
|
|
||||||
logprobs: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class ChatCompletionRequest:
|
|
||||||
model: InstructModel
|
|
||||||
dialog: Dialog
|
|
||||||
sampling_params: SamplingParams = SamplingParams()
|
|
||||||
|
|
||||||
# zero-shot tool definitions as input to the model
|
|
||||||
available_tools: List[ToolDefinition] = field(default_factory=list)
|
|
||||||
|
|
||||||
max_tokens: int = 0
|
|
||||||
stream: bool = False
|
|
||||||
logprobs: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class ChatCompletionResponse:
|
|
||||||
"""Normal chat completion response."""
|
|
||||||
|
|
||||||
content: Content
|
|
||||||
|
|
||||||
# note: multiple tool calls can be generated in a single response
|
|
||||||
tool_calls: List[ToolCall] = field(default_factory=list)
|
|
||||||
|
|
||||||
stop_reason: Optional[StopReason] = None
|
|
||||||
logprobs: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class ChatCompletionResponseStreamChunk:
|
|
||||||
"""Streamed chat completion response. The actual response is a series of such objects."""
|
|
||||||
|
|
||||||
text_delta: str
|
|
||||||
stop_reason: Optional[StopReason] = None
|
|
||||||
tool_call: Optional[ToolCall] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class BatchCompletionRequest:
|
|
||||||
model: PretrainedModel
|
|
||||||
content_batch: List[Content]
|
|
||||||
sampling_params: SamplingParams = SamplingParams()
|
|
||||||
max_tokens: int = 0
|
|
||||||
logprobs: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class BatchChatCompletionRequest:
|
|
||||||
model: InstructModel
|
|
||||||
batch_dialogs: List[Dialog]
|
|
||||||
sampling_params: SamplingParams = SamplingParams()
|
|
||||||
|
|
||||||
# zero-shot tool definitions as input to the model
|
|
||||||
available_tools: List[ToolDefinition] = field(default_factory=list)
|
|
||||||
|
|
||||||
max_tokens: int = 0
|
|
||||||
logprobs: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class Inference(Protocol):
|
|
||||||
|
|
||||||
@webmethod(route="/inference/completion")
|
|
||||||
def post_completion(
|
|
||||||
self,
|
|
||||||
request: CompletionRequest,
|
|
||||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/inference/chat_completion")
|
|
||||||
def post_chat_completion(
|
|
||||||
self,
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/inference/batch_completion")
|
|
||||||
def post_batch_completion(
|
|
||||||
self,
|
|
||||||
request: BatchCompletionRequest,
|
|
||||||
) -> List[CompletionResponse]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/inference/batch_chat_completion")
|
|
||||||
def post_batch_chat_completion(
|
|
||||||
self,
|
|
||||||
request: BatchChatCompletionRequest,
|
|
||||||
) -> List[ChatCompletionResponse]: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgenticSystemCreateRequest:
|
|
||||||
uuid: str
|
|
||||||
|
|
||||||
instructions: str
|
|
||||||
model: InstructModel
|
|
||||||
|
|
||||||
# zero-shot or built-in tool configurations as input to the model
|
|
||||||
available_tools: List[ToolDefinition] = field(default_factory=list)
|
|
||||||
|
|
||||||
# tools which aren't executable are emitted as tool calls which the users can
|
|
||||||
# execute themselves.
|
|
||||||
executable_tools: Set[str] = field(default_factory=set)
|
|
||||||
|
|
||||||
memory_bank_uuids: List[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
input_shields: List[ShieldConfig] = field(default_factory=list)
|
|
||||||
output_shields: List[ShieldConfig] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class AgenticSystemCreateResponse:
|
|
||||||
agent_uuid: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class AgenticSystemExecuteRequest:
|
|
||||||
agent_uuid: str
|
|
||||||
messages: List[Message]
|
|
||||||
turn_history: List[AgenticSystemTurn] = None
|
|
||||||
stream: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class AgenticSystemExecuteResponse:
|
|
||||||
"""non-stream response from the agentic system."""
|
|
||||||
|
|
||||||
turn: AgenticSystemTurn
|
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystemExecuteResponseEventType(Enum):
|
|
||||||
"""The type of event."""
|
|
||||||
|
|
||||||
step_start = "step_start"
|
|
||||||
step_end = "step_end"
|
|
||||||
step_progress = "step_progress"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class AgenticSystemExecuteResponseStreamChunk:
|
|
||||||
"""Streamed agent execution response."""
|
|
||||||
|
|
||||||
event_type: AgenticSystemExecuteResponseEventType
|
|
||||||
|
|
||||||
step_uuid: str
|
|
||||||
step_type: ExecutionStepType
|
|
||||||
|
|
||||||
# TODO(ashwin): maybe add more structure here and do this as a proper tagged union
|
|
||||||
violation: Optional[SafetyViolation] = None
|
|
||||||
tool_call: Optional[ToolCall] = None
|
|
||||||
tool_response_delta: Optional[ToolResponse] = None
|
|
||||||
response_text_delta: Optional[str] = None
|
|
||||||
retrieved_document: Optional[MemoryBankDocument] = None
|
|
||||||
|
|
||||||
stop_reason: Optional[StopReason] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystem(Protocol):
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/create")
|
|
||||||
def create_agentic_system(
|
|
||||||
self,
|
|
||||||
request: AgenticSystemCreateRequest,
|
|
||||||
) -> AgenticSystemCreateResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/execute")
|
|
||||||
def create_agentic_system_execute(
|
|
||||||
self,
|
|
||||||
request: AgenticSystemExecuteRequest,
|
|
||||||
) -> Union[
|
|
||||||
AgenticSystemExecuteResponse, AgenticSystemExecuteResponseStreamChunk
|
|
||||||
]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/delete")
|
|
||||||
def delete_agentic_system(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryBanks(Protocol):
|
|
||||||
@webmethod(route="/memory_banks/create")
|
|
||||||
def post_create_memory_bank(
|
|
||||||
self,
|
|
||||||
bank_uuid: str,
|
|
||||||
bank_name: str,
|
|
||||||
documents: List[MemoryBankDocument],
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get")
|
|
||||||
def get_memory_banks(
|
|
||||||
self
|
|
||||||
) -> List[MemoryBank]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/drop")
|
|
||||||
def delete_memory_bank(
|
|
||||||
self,
|
|
||||||
bank_uuid: str,
|
|
||||||
) -> str: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/insert")
|
|
||||||
def post_insert_memory_documents(
|
|
||||||
self,
|
|
||||||
bank_uuid: str,
|
|
||||||
documents: List[MemoryBankDocument],
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/update")
|
|
||||||
def post_update_memory_documents(
|
|
||||||
self,
|
|
||||||
bank_uuid: str,
|
|
||||||
documents: List[MemoryBankDocument],
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/get")
|
|
||||||
def get_memory_documents(
|
|
||||||
self,
|
|
||||||
bank_uuid: str,
|
|
||||||
document_uuids: List[str],
|
|
||||||
) -> List[MemoryBankDocument]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/delete")
|
|
||||||
def delete_memory_documents(
|
|
||||||
self,
|
|
||||||
bank_uuid: str,
|
|
||||||
document_uuids: List[str],
|
|
||||||
) -> List[str]: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class KPromptGenerations:
|
|
||||||
dialog: Dialog
|
|
||||||
k_generations: List[Message]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class ScoredMessage:
|
|
||||||
message: Message
|
|
||||||
score: float
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class KScoredPromptGenerations:
|
|
||||||
prompt: Message
|
|
||||||
k_scored_generations: List[ScoredMessage]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class RewardScoringRequest:
|
|
||||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
|
||||||
|
|
||||||
prompt_generations: List[KPromptGenerations]
|
|
||||||
model: RewardModel
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class RewardScoringResponse:
|
|
||||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
|
||||||
|
|
||||||
scored_generations: List[KScoredPromptGenerations]
|
|
||||||
|
|
||||||
|
|
||||||
class RewardScoring(Protocol):
|
|
||||||
@webmethod(route="/reward_scoring/score")
|
|
||||||
def post_score(
|
|
||||||
self,
|
|
||||||
request: RewardScoringRequest,
|
|
||||||
) -> Union[RewardScoringResponse]: ...
|
|
||||||
|
|
||||||
|
|
||||||
class FilteringFunction(Enum):
|
|
||||||
"""The type of filtering function."""
|
|
||||||
|
|
||||||
none = "none"
|
|
||||||
random = "random"
|
|
||||||
top_k = "top_k"
|
|
||||||
top_p = "top_p"
|
|
||||||
top_k_top_p = "top_k_top_p"
|
|
||||||
sigmoid = "sigmoid"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class SyntheticDataGenerationRequest:
|
|
||||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
|
||||||
|
|
||||||
prompts: List[Message]
|
|
||||||
filtering_function: FilteringFunction = FilteringFunction.none
|
|
||||||
reward_scoring: Optional[RewardScoring] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class SyntheticDataGenerationResponse:
|
|
||||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
|
||||||
|
|
||||||
synthetic_data: List[KScoredPromptGenerations]
|
|
||||||
statistics: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDataGeneration(Protocol):
|
|
||||||
@webmethod(route="/synthetic_data_generation/generate")
|
|
||||||
def post_generate(
|
|
||||||
self,
|
|
||||||
request: SyntheticDataGenerationRequest,
|
|
||||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class CreateDatasetRequest:
|
|
||||||
"""Request to create a dataset."""
|
|
||||||
|
|
||||||
uuid: str
|
|
||||||
dataset: Dataset
|
|
||||||
|
|
||||||
|
|
||||||
class Datasets(Protocol):
|
|
||||||
@webmethod(route="/datasets/create")
|
|
||||||
def create_dataset(
|
|
||||||
self,
|
|
||||||
request: CreateDatasetRequest,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/datasets/get")
|
|
||||||
def get_dataset(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
) -> Dataset: ...
|
|
||||||
|
|
||||||
@webmethod(route="/datasets/delete")
|
|
||||||
def delete_dataset(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class PostTrainingSFTRequest:
|
|
||||||
"""Request to finetune a model."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
|
|
||||||
model: PretrainedModel
|
|
||||||
dataset: Dataset
|
|
||||||
validation_dataset: Dataset
|
|
||||||
|
|
||||||
algorithm: FinetuningAlgorithm
|
|
||||||
algorithm_config: Union[
|
|
||||||
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
|
|
||||||
]
|
|
||||||
|
|
||||||
optimizer_config: OptimizerConfig
|
|
||||||
training_config: TrainingConfig
|
|
||||||
|
|
||||||
# TODO: define these
|
|
||||||
hyperparam_search_config: Dict[str, Any]
|
|
||||||
logger_config: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class PostTrainingRLHFRequest:
|
|
||||||
"""Request to finetune a model."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
|
|
||||||
finetuned_model: URL
|
|
||||||
|
|
||||||
dataset: Dataset
|
|
||||||
validation_dataset: Dataset
|
|
||||||
|
|
||||||
algorithm: RLHFAlgorithm
|
|
||||||
algorithm_config: Union[DPOAlignmentConfig]
|
|
||||||
|
|
||||||
optimizer_config: OptimizerConfig
|
|
||||||
training_config: TrainingConfig
|
|
||||||
|
|
||||||
# TODO: define these
|
|
||||||
hyperparam_search_config: Dict[str, Any]
|
|
||||||
logger_config: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class PostTrainingJobStatusResponse:
|
|
||||||
"""Status of a finetuning job."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
status: PostTrainingJobStatus
|
|
||||||
|
|
||||||
scheduled_at: Optional[datetime] = None
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
resources_allocated: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
checkpoints: List[Checkpoint] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class PostTrainingJobArtifactsResponse:
|
|
||||||
"""Artifacts of a finetuning job."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
checkpoints: List[Checkpoint] = field(default_factory=list)
|
|
||||||
|
|
||||||
# TODO(ashwin): metrics, evals
|
|
||||||
|
|
||||||
|
|
||||||
class PostTraining(Protocol):
|
|
||||||
@webmethod(route="/post_training/supervised_fine_tune/")
|
|
||||||
def post_supervised_fine_tune(
|
|
||||||
self,
|
|
||||||
request: PostTrainingSFTRequest,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post_training/preference_optimize/")
|
|
||||||
def post_preference_optimize(
|
|
||||||
self,
|
|
||||||
request: PostTrainingRLHFRequest,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
# sends SSE stream of logs
|
|
||||||
@webmethod(route="/post_training/job/logs")
|
|
||||||
def get_training_log_stream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/status")
|
|
||||||
def get_training_job_status(
|
|
||||||
self, job_uuid: str
|
|
||||||
) -> PostTrainingJobStatusResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/cancel")
|
|
||||||
def cancel_training_job(self, job_uuid: str) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/artifacts")
|
|
||||||
def get_training_job_artifacts(
|
|
||||||
self, job_uuid: str
|
|
||||||
) -> PostTrainingJobArtifactsResponse: ...
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaStackEndpoints(
|
|
||||||
Inference,
|
|
||||||
AgenticSystem,
|
|
||||||
RewardScoring,
|
|
||||||
SyntheticDataGeneration,
|
|
||||||
Datasets,
|
|
||||||
PostTraining,
|
|
||||||
MemoryBanks,
|
|
||||||
): ...
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("Converting the spec to YAML (openapi.yaml) and HTML (openapi.html)")
|
|
||||||
spec = Specification(
|
|
||||||
LlamaStackEndpoints,
|
|
||||||
Options(
|
|
||||||
server=Server(url="http://any-hosted-llama-stack.com"),
|
|
||||||
info=Info(
|
|
||||||
title="[DRAFT] Llama Stack Specification",
|
|
||||||
version="0.0.1",
|
|
||||||
description="""This is the specification of the llama stack that provides
|
|
||||||
a set of endpoints and their corresponding interfaces that are tailored to
|
|
||||||
best leverage Llama Models. The specification is still in draft and subject to change.""",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
with open("openapi.yaml", "w", encoding="utf-8") as fp:
|
|
||||||
yaml.dump(spec.get_json(), fp, allow_unicode=True)
|
|
||||||
|
|
||||||
with open("openapi.html", "w") as fp:
|
|
||||||
spec.write_html(fp, pretty_print=True)
|
|
|
@ -1,59 +0,0 @@
|
||||||
import requests
|
|
||||||
from dataclasses import dataclass, field, asdict
|
|
||||||
from typing import List, Set, Optional, Union, Protocol
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from model_types import *
|
|
||||||
from agentic_system_types import *
|
|
||||||
from api_definitions import *
|
|
||||||
|
|
||||||
class EnumEncoder(json.JSONEncoder):
|
|
||||||
def default(self, obj):
|
|
||||||
if isinstance(obj, Enum):
|
|
||||||
return obj.value
|
|
||||||
elif isinstance(obj, set):
|
|
||||||
return list(obj)
|
|
||||||
return json.JSONEncoder.default(self, obj)
|
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystemClient:
|
|
||||||
def __init__(self, base_url: str):
|
|
||||||
self.base_url = base_url
|
|
||||||
|
|
||||||
def create_agentic_system(self, request: AgenticSystemCreateRequest) -> AgenticSystemCreateResponse:
|
|
||||||
response = requests.post(f"{self.base_url}/agentic_system/create", data=json.dumps(asdict(request), cls=EnumEncoder), headers={'Content-Type': 'application/json'})
|
|
||||||
response.raise_for_status()
|
|
||||||
return AgenticSystemCreateResponse(**response.json())
|
|
||||||
|
|
||||||
def execute_agentic_system(self, request: AgenticSystemExecuteRequest) -> Union[AgenticSystemExecuteResponse, AgenticSystemExecuteResponseStreamChunk]:
|
|
||||||
response = requests.post(f"{self.base_url}/agentic_system/execute", data=json.dumps(asdict(request), cls=EnumEncoder), headers={'Content-Type': 'application/json'})
|
|
||||||
response.raise_for_status()
|
|
||||||
response_json = response.json()
|
|
||||||
if 'turn' in response_json:
|
|
||||||
return AgenticSystemExecuteResponse(**response_json)
|
|
||||||
else:
|
|
||||||
return AgenticSystemExecuteResponseStreamChunk(**response_json)
|
|
||||||
|
|
||||||
# Example usage
|
|
||||||
if __name__ == "__main__":
|
|
||||||
client = AgenticSystemClient("http://localhost:5000")
|
|
||||||
|
|
||||||
# Create a new agentic system
|
|
||||||
create_request = AgenticSystemCreateRequest(
|
|
||||||
instructions="Your instructions here",
|
|
||||||
model=InstructModel.llama3_8b_chat,
|
|
||||||
)
|
|
||||||
create_response = client.create_agentic_system(create_request)
|
|
||||||
print("Agent ID:", create_response.agent_id)
|
|
||||||
|
|
||||||
# Execute the agentic system
|
|
||||||
execute_request = AgenticSystemExecuteRequest(
|
|
||||||
agent_id=create_response.agent_id,
|
|
||||||
messages=[Message(role="user", content="Tell me a joke")],
|
|
||||||
turn_history=[],
|
|
||||||
stream=False
|
|
||||||
)
|
|
||||||
execute_response = client.execute_agentic_system(execute_request)
|
|
||||||
print("Execute Response:", execute_response)
|
|
Binary file not shown.
|
@ -1,14 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
set -x
|
|
||||||
|
|
||||||
export JAVA_HOME=/usr/local/java-runtime/impl/11
|
|
||||||
|
|
||||||
$JAVA_HOME/bin/java -jar codegen/openapi-generator-cli.jar \
|
|
||||||
generate \
|
|
||||||
-i openapi.yaml \
|
|
||||||
-g python-flask \
|
|
||||||
-o /tmp/foo \
|
|
||||||
--log-to-stderr \
|
|
||||||
--global-property debugModels,debugOperations,debugOpenAPI,debugSupportingFiles
|
|
|
@ -1,3 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
PYTHONPATH=. python3 api_definitions.py
|
|
|
@ -1,149 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
|
||||||
|
|
||||||
from strong_typing.schema import json_schema_type
|
|
||||||
|
|
||||||
|
|
||||||
class ShieldType(Enum):
|
|
||||||
"""The type of safety shield."""
|
|
||||||
|
|
||||||
llama_guard = "llama_guard"
|
|
||||||
prompt_guard = "prompt_guard"
|
|
||||||
code_guard = "code_guard"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class ShieldConfig:
|
|
||||||
shield_type: ShieldType
|
|
||||||
params: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SafetyViolation:
|
|
||||||
violation_type: str
|
|
||||||
details: str
|
|
||||||
suggested_user_response: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type(
|
|
||||||
schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
|
|
||||||
)
|
|
||||||
@dataclass
|
|
||||||
class URL:
|
|
||||||
url: str
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.url
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class Attachment:
|
|
||||||
"""
|
|
||||||
Attachments are used to refer to external resources, such as images, videos, audio, etc.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
url: URL
|
|
||||||
mime_type: str
|
|
||||||
|
|
||||||
# TODO(ashwin): make this better named maybe InterleavedTextMedia
|
|
||||||
Content = Union[
|
|
||||||
str,
|
|
||||||
Attachment,
|
|
||||||
List[Union[str, Attachment]],
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Role(Enum):
|
|
||||||
system = "system"
|
|
||||||
user = "user"
|
|
||||||
assistant = "assistant"
|
|
||||||
tool = "tool"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolCall:
|
|
||||||
"""
|
|
||||||
A tool call is a request to a tool.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tool_name: str
|
|
||||||
arguments: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolResponse:
|
|
||||||
tool_name: str
|
|
||||||
content: Content
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: we need to document the parameters for the tool calls
|
|
||||||
class BuiltinTool(Enum):
|
|
||||||
web_search = "web_search"
|
|
||||||
math = "math"
|
|
||||||
image_gen = "image_gen"
|
|
||||||
code_interpreter = "code_interpreter"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolDefinition:
|
|
||||||
tool_name: Union[BuiltinTool, str]
|
|
||||||
parameters: Optional[Dict[str, Any]] = None
|
|
||||||
input_shields: List[ShieldConfig] = field(default_factory=list)
|
|
||||||
output_shields: List[ShieldConfig] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class StopReason(Enum):
|
|
||||||
"""
|
|
||||||
Stop reasons are used to indicate why the model stopped generating text.
|
|
||||||
"""
|
|
||||||
|
|
||||||
not_stopped = "not_stopped"
|
|
||||||
finished_ok = "finished_ok"
|
|
||||||
max_tokens = "max_tokens"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class Message:
|
|
||||||
role: Role
|
|
||||||
|
|
||||||
# input to the model or output from the model
|
|
||||||
content: Content
|
|
||||||
|
|
||||||
# output from the model
|
|
||||||
tool_calls: List[ToolCall] = field(default_factory=list)
|
|
||||||
|
|
||||||
# input to the model
|
|
||||||
tool_responses: List[ToolResponse] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class Dialog:
|
|
||||||
message: Message
|
|
||||||
message_history: List[Message] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SamplingParams:
|
|
||||||
temperature: float = 0.0
|
|
||||||
strategy: str = "greedy"
|
|
||||||
top_p: float = 0.95
|
|
||||||
top_k: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class PretrainedModel(Enum):
|
|
||||||
llama3_8b = "llama3_8b"
|
|
||||||
llama3_70b = "llama3_70b"
|
|
||||||
|
|
||||||
|
|
||||||
class InstructModel(Enum):
|
|
||||||
llama3_8b_chat = "llama3_8b_chat"
|
|
||||||
llama3_70b_chat = "llama3_70b_chat"
|
|
||||||
|
|
||||||
class RewardModel(Enum):
|
|
||||||
llama3_405b_reward = "llama3_405b_reward"
|
|
3597
source/openapi.html
3597
source/openapi.html
File diff suppressed because it is too large
Load diff
2249
source/openapi.yaml
2249
source/openapi.yaml
File diff suppressed because it is too large
Load diff
|
@ -1,119 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
|
||||||
|
|
||||||
from model_types import Message, URL
|
|
||||||
|
|
||||||
from strong_typing.schema import json_schema_type
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetColumnType(Enum):
|
|
||||||
dialog = "dialog"
|
|
||||||
text = "text"
|
|
||||||
media = "media"
|
|
||||||
number = "number"
|
|
||||||
json = "json"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class Dataset:
|
|
||||||
"""Dataset to be used for training or evaluating language models."""
|
|
||||||
|
|
||||||
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
|
|
||||||
|
|
||||||
columns: Dict[str, DatasetColumnType]
|
|
||||||
content_url: URL
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerType(Enum):
|
|
||||||
adam = "adam"
|
|
||||||
adamw = "adamw"
|
|
||||||
sgd = "sgd"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class OptimizerConfig:
|
|
||||||
optimizer_type: OptimizerType
|
|
||||||
lr: float
|
|
||||||
lr_min: float
|
|
||||||
weight_decay: float
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class TrainingConfig:
|
|
||||||
n_epochs: int
|
|
||||||
batch_size: int
|
|
||||||
shuffle: bool
|
|
||||||
n_iters: int
|
|
||||||
|
|
||||||
enable_activation_checkpointing: bool
|
|
||||||
memory_efficient_fsdp_wrap: bool
|
|
||||||
fsdp_cpu_offload: bool
|
|
||||||
|
|
||||||
|
|
||||||
class FinetuningAlgorithm(Enum):
|
|
||||||
full = "full"
|
|
||||||
lora = "lora"
|
|
||||||
qlora = "qlora"
|
|
||||||
dora = "dora"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class LoraFinetuningConfig:
|
|
||||||
lora_attn_modules: List[str]
|
|
||||||
apply_lora_to_mlp: bool
|
|
||||||
apply_lora_to_output: bool
|
|
||||||
rank: int
|
|
||||||
alpha: int
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class DoraFinetuningConfig(LoraFinetuningConfig):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class PostTrainingJobLogStream:
|
|
||||||
"""Stream of logs from a finetuning job."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
log_lines: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
class PostTrainingJobStatus(Enum):
|
|
||||||
running = "running"
|
|
||||||
completed = "completed"
|
|
||||||
failed = "failed"
|
|
||||||
scheduled = "scheduled"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Checkpoint:
|
|
||||||
iters: int
|
|
||||||
path: URL
|
|
||||||
|
|
||||||
|
|
||||||
class RLHFAlgorithm(Enum):
|
|
||||||
dpo = "dpo"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
@dataclass
|
|
||||||
class DPOAlignmentConfig:
|
|
||||||
reward_scale: float
|
|
||||||
reward_clip: float
|
|
||||||
epsilon: float
|
|
||||||
gamma: float
|
|
|
@ -1,47 +0,0 @@
|
||||||
from flask import Flask, request, jsonify
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import List, Set, Optional, Union, Protocol
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
from model_types import *
|
|
||||||
from agentic_system_types import *
|
|
||||||
from api_definitions import *
|
|
||||||
|
|
||||||
class AgenticSystemImpl(AgenticSystem):
|
|
||||||
def create_agentic_system(self, request: AgenticSystemCreateRequest) -> AgenticSystemCreateResponse:
|
|
||||||
# Mock implementation
|
|
||||||
return AgenticSystemCreateResponse(agent_id="12345")
|
|
||||||
|
|
||||||
def create_agentic_system_execute(self, request: AgenticSystemExecuteRequest) -> Union[AgenticSystemExecuteResponse, AgenticSystemExecuteResponseStreamChunk]:
|
|
||||||
# Mock implementation
|
|
||||||
return AgenticSystemExecuteResponse(
|
|
||||||
turn=AgenticSystemTurn(
|
|
||||||
user_messages=[],
|
|
||||||
steps=[],
|
|
||||||
response_message=Message(
|
|
||||||
role="assistant",
|
|
||||||
content="Hello, I am an agent. I can help you with your tasks. What can I help you with?",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
agentic_system = AgenticSystemImpl()
|
|
||||||
|
|
||||||
@app.route("/agentic_system/create", methods=["POST"])
|
|
||||||
def create_agentic_system():
|
|
||||||
data = request.json
|
|
||||||
create_request = AgenticSystemCreateRequest(**data)
|
|
||||||
response = agentic_system.create_agentic_system(create_request)
|
|
||||||
return jsonify(response)
|
|
||||||
|
|
||||||
@app.route("/agentic_system/execute", methods=["POST"])
|
|
||||||
def create_agentic_system_execute():
|
|
||||||
data = request.json
|
|
||||||
execute_request = AgenticSystemExecuteRequest(**data)
|
|
||||||
response = agentic_system.create_agentic_system_execute(execute_request)
|
|
||||||
return jsonify(response)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(debug=True)
|
|
Loading…
Add table
Add a link
Reference in a new issue