# Mirror of https://github.com/meta-llama/llama-stack.git
# Synced 2025-10-03 19:57:35 +00:00
# 586 lines, 16 KiB, Python
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Union
|
||
|
||
import yaml
|
||
from agentic_system_types import (
|
||
AgenticSystemTurn,
|
||
ExecutionStepType,
|
||
MemoryBank,
|
||
MemoryBankDocument,
|
||
SafetyViolation,
|
||
)
|
||
|
||
from model_types import (
|
||
BuiltinTool,
|
||
Content,
|
||
Dialog,
|
||
InstructModel,
|
||
Message,
|
||
PretrainedModel,
|
||
RewardModel,
|
||
SamplingParams,
|
||
ShieldConfig,
|
||
StopReason,
|
||
ToolCall,
|
||
ToolDefinition,
|
||
ToolResponse,
|
||
URL,
|
||
)
|
||
|
||
from post_training_types import (
|
||
Checkpoint,
|
||
Dataset,
|
||
DoraFinetuningConfig,
|
||
DPOAlignmentConfig,
|
||
FinetuningAlgorithm,
|
||
LoraFinetuningConfig,
|
||
OptimizerConfig,
|
||
PostTrainingJobLogStream,
|
||
PostTrainingJobStatus,
|
||
QLoraFinetuningConfig,
|
||
RLHFAlgorithm,
|
||
TrainingConfig,
|
||
)
|
||
|
||
from pyopenapi import Info, Options, Server, Specification, webmethod
|
||
from strong_typing.schema import json_schema_type
|
||
|
||
|
||
# Request payload accepted by Inference.post_completion.
@json_schema_type
@dataclass
class CompletionRequest:
    # Content to complete and the pretrained model that should complete it.
    content: Content
    model: PretrainedModel
    # Built per instance via default_factory: a class-level `SamplingParams()`
    # would be one object shared by every request, and Python 3.11+ rejects
    # unhashable instances as dataclass defaults.
    # NOTE(review): if the schema generator reads class-level defaults, confirm
    # the emitted default for this field is still what you expect.
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    # presumably 0 means "no explicit limit" -- TODO confirm with implementers
    max_tokens: int = 0
    # Flags; the endpoint's return type allows either a full response or a
    # stream chunk, presumably selected by `stream` -- TODO confirm.
    stream: bool = False
    logprobs: bool = False
|
||
|
||
|
||
@json_schema_type
@dataclass
class CompletionResponse:
    """Normal completion response."""

    # The only required field: the generated content.
    content: Content

    # Optional metadata; both default to None when not supplied.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
@json_schema_type
@dataclass
class CompletionResponseStreamChunk:
    """streamed completion response."""

    # Text carried by this chunk.
    text_delta: str

    # Optional metadata; both default to None when not supplied.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
# Request payload accepted by Inference.post_chat_completion.
@json_schema_type
@dataclass
class ChatCompletionRequest:
    # Instruction-tuned model to run and the dialog it should continue.
    model: InstructModel
    dialog: Dialog
    # Built per instance via default_factory: a class-level `SamplingParams()`
    # would be one object shared by every request, and Python 3.11+ rejects
    # unhashable instances as dataclass defaults.
    # NOTE(review): if the schema generator reads class-level defaults, confirm
    # the emitted default for this field is still what you expect.
    sampling_params: SamplingParams = field(default_factory=SamplingParams)

    # zero-shot tool definitions as input to the model
    available_tools: List[ToolDefinition] = field(default_factory=list)

    # presumably 0 means "no explicit limit" -- TODO confirm with implementers
    max_tokens: int = 0
    stream: bool = False
    logprobs: bool = False
|
||
|
||
|
||
@json_schema_type
@dataclass
class ChatCompletionResponse:
    """Normal chat completion response."""

    # The only required field: the generated content.
    content: Content

    # note: multiple tool calls can be generated in a single response
    # (each instance starts with its own empty list)
    tool_calls: List[ToolCall] = field(default_factory=list)

    # Optional metadata; both default to None when not supplied.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
@json_schema_type
@dataclass
class ChatCompletionResponseStreamChunk:
    """Streamed chat completion response. The actual response is a series of such objects."""

    # Text carried by this chunk.
    text_delta: str

    # Optional per-chunk extras; both default to None.
    stop_reason: Optional[StopReason] = None
    tool_call: Optional[ToolCall] = None
|
||
|
||
|
||
# Request payload accepted by Inference.post_batch_completion.
@json_schema_type
@dataclass
class BatchCompletionRequest:
    # Pretrained model to run and the list of contents to complete.
    model: PretrainedModel
    content_batch: List[Content]
    # Built per instance via default_factory: a class-level `SamplingParams()`
    # would be one object shared by every request, and Python 3.11+ rejects
    # unhashable instances as dataclass defaults.
    # NOTE(review): if the schema generator reads class-level defaults, confirm
    # the emitted default for this field is still what you expect.
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    # presumably 0 means "no explicit limit" -- TODO confirm with implementers
    max_tokens: int = 0
    logprobs: bool = False
|
||
|
||
|
||
# Request payload accepted by Inference.post_batch_chat_completion.
@json_schema_type
@dataclass
class BatchChatCompletionRequest:
    # Instruction-tuned model to run and the list of dialogs to continue.
    model: InstructModel
    batch_dialogs: List[Dialog]
    # Built per instance via default_factory: a class-level `SamplingParams()`
    # would be one object shared by every request, and Python 3.11+ rejects
    # unhashable instances as dataclass defaults.
    # NOTE(review): if the schema generator reads class-level defaults, confirm
    # the emitted default for this field is still what you expect.
    sampling_params: SamplingParams = field(default_factory=SamplingParams)

    # zero-shot tool definitions as input to the model
    available_tools: List[ToolDefinition] = field(default_factory=list)

    # presumably 0 means "no explicit limit" -- TODO confirm with implementers
    max_tokens: int = 0
    logprobs: bool = False
|
||
|
||
|
||
class Inference(Protocol):
    # Interface for the /inference/* routes. Every method body is a stub
    # (`...`); only the signatures and route decorators carry information.

    # Single completion: returns either a full response or a stream chunk.
    @webmethod(route="/inference/completion")
    def post_completion(
        self, request: CompletionRequest
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

    # Single chat completion: returns either a full response or a stream chunk.
    @webmethod(route="/inference/chat_completion")
    def post_chat_completion(
        self, request: ChatCompletionRequest
    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...

    # Batched completion: returns a list of full responses.
    @webmethod(route="/inference/batch_completion")
    def post_batch_completion(
        self, request: BatchCompletionRequest
    ) -> List[CompletionResponse]: ...

    # Batched chat completion: returns a list of full responses.
    @webmethod(route="/inference/batch_chat_completion")
    def post_batch_chat_completion(
        self, request: BatchChatCompletionRequest
    ) -> List[ChatCompletionResponse]: ...
|
||
|
||
|
||
# Request payload accepted by AgenticSystem.create_agentic_system.
# NOTE(review): unlike the other request types in this file, this class is not
# decorated with @json_schema_type -- confirm whether that is intentional.
@dataclass
class AgenticSystemCreateRequest:
    uuid: str

    instructions: str
    model: InstructModel

    # zero-shot or built-in tool configurations as input to the model
    available_tools: List[ToolDefinition] = field(default_factory=list)

    # tools which aren't executable are emitted as tool calls which the users can
    # execute themselves.
    executable_tools: Set[str] = field(default_factory=set)

    # Each of the list fields below starts out as its own empty list.
    memory_bank_uuids: List[str] = field(default_factory=list)

    input_shields: List[ShieldConfig] = field(default_factory=list)
    output_shields: List[ShieldConfig] = field(default_factory=list)
|
||
|
||
|
||
# Return type of AgenticSystem.create_agentic_system.
@json_schema_type
@dataclass
class AgenticSystemCreateResponse:
    # Same field name that AgenticSystemExecuteRequest takes as input.
    agent_uuid: str
|
||
|
||
|
||
# Request payload accepted by AgenticSystem.create_agentic_system_execute.
@json_schema_type
@dataclass
class AgenticSystemExecuteRequest:
    agent_uuid: str
    messages: List[Message]
    # The default is None, so the annotation must admit None; the previous
    # `List[AgenticSystemTurn] = None` contradicted its own type. The default
    # value itself is unchanged, so callers that test `is None` still work.
    turn_history: Optional[List[AgenticSystemTurn]] = None
    # The endpoint's return type allows either a full response or a stream
    # chunk, presumably selected by this flag -- TODO confirm.
    stream: bool = False
|
||
|
||
|
||
@json_schema_type
@dataclass
class AgenticSystemExecuteResponse:
    """non-stream response from the agentic system."""

    # The response consists of exactly one turn object.
    turn: AgenticSystemTurn
|
||
|
||
|
||
class AgenticSystemExecuteResponseEventType(Enum):
    """The type of event."""

    # Each member's value is the same string as its name; declaration order
    # (start, end, progress) is preserved because Enum iteration follows it.
    step_start = "step_start"
    step_end = "step_end"
    step_progress = "step_progress"
|
||
|
||
|
||
@json_schema_type
@dataclass
class AgenticSystemExecuteResponseStreamChunk:
    """Streamed agent execution response."""

    # Required on every chunk: the event kind plus the step it refers to.
    event_type: AgenticSystemExecuteResponseEventType

    step_uuid: str
    step_type: ExecutionStepType

    # TODO(ashwin): maybe add more structure here and do this as a proper tagged union
    # All payload fields below are optional and default to None.
    violation: Optional[SafetyViolation] = None
    tool_call: Optional[ToolCall] = None
    tool_response_delta: Optional[ToolResponse] = None
    response_text_delta: Optional[str] = None
    retrieved_document: Optional[MemoryBankDocument] = None

    stop_reason: Optional[StopReason] = None
|
||
|
||
|
||
class AgenticSystem(Protocol):
    # Interface for the /agentic_system/* routes; method bodies are stubs.

    # Create an agentic system from the given configuration.
    @webmethod(route="/agentic_system/create")
    def create_agentic_system(
        self, request: AgenticSystemCreateRequest
    ) -> AgenticSystemCreateResponse: ...

    # Execute: returns either a complete response or a stream chunk.
    @webmethod(route="/agentic_system/execute")
    def create_agentic_system_execute(
        self, request: AgenticSystemExecuteRequest
    ) -> Union[
        AgenticSystemExecuteResponse,
        AgenticSystemExecuteResponseStreamChunk,
    ]: ...

    # NOTE(review): this parameter is named `agent_id`, while the request and
    # response types in this file use `agent_uuid` -- confirm which is intended.
    @webmethod(route="/agentic_system/delete")
    def delete_agentic_system(self, agent_id: str) -> None: ...
|
||
|
||
|
||
class MemoryBanks(Protocol):
    # Interface for memory banks. Bank-level routes live under /memory_banks/*
    # and document-level routes under /memory_bank/*; method bodies are stubs.

    # --- bank-level operations -------------------------------------------

    @webmethod(route="/memory_banks/create")
    def post_create_memory_bank(
        self, bank_uuid: str, bank_name: str, documents: List[MemoryBankDocument]
    ) -> None: ...

    @webmethod(route="/memory_banks/get")
    def get_memory_banks(self) -> List[MemoryBank]: ...

    # Note the route is ".../drop" although the method is named delete_*.
    @webmethod(route="/memory_banks/drop")
    def delete_memory_bank(self, bank_uuid: str) -> str: ...

    # --- document-level operations ---------------------------------------

    @webmethod(route="/memory_bank/insert")
    def post_insert_memory_documents(
        self, bank_uuid: str, documents: List[MemoryBankDocument]
    ) -> None: ...

    @webmethod(route="/memory_bank/update")
    def post_update_memory_documents(
        self, bank_uuid: str, documents: List[MemoryBankDocument]
    ) -> None: ...

    @webmethod(route="/memory_bank/get")
    def get_memory_documents(
        self, bank_uuid: str, document_uuids: List[str]
    ) -> List[MemoryBankDocument]: ...

    @webmethod(route="/memory_bank/delete")
    def delete_memory_documents(
        self, bank_uuid: str, document_uuids: List[str]
    ) -> List[str]: ...
|
||
|
||
|
||
# A dialog paired with a list of generated messages; used as the element type
# of RewardScoringRequest.prompt_generations.
@dataclass
class KPromptGenerations:
    dialog: Dialog
    k_generations: List[Message]
|
||
|
||
|
||
# A message together with a float score.
@json_schema_type
@dataclass
class ScoredMessage:
    message: Message
    score: float
|
||
|
||
|
||
# A prompt message paired with a list of scored generations.
@json_schema_type
@dataclass
class KScoredPromptGenerations:
    prompt: Message
    k_scored_generations: List[ScoredMessage]
|
||
|
||
|
||
@json_schema_type
@dataclass
class RewardScoringRequest:
    """Request to score a reward function. A list of prompts and a list of responses per prompt."""

    # Items to score, followed by the reward model that scores them.
    prompt_generations: List[KPromptGenerations]
    model: RewardModel
|
||
|
||
|
||
@json_schema_type
@dataclass
class RewardScoringResponse:
    """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""

    # One entry per prompt, each holding that prompt's scored generations.
    scored_generations: List[KScoredPromptGenerations]
|
||
|
||
|
||
class RewardScoring(Protocol):
    # Interface for the reward-scoring route; the method body is a stub.

    # The return annotation was spelled `Union[RewardScoringResponse]`; a
    # single-member Union evaluates to that member, so it is written directly.
    @webmethod(route="/reward_scoring/score")
    def post_score(self, request: RewardScoringRequest) -> RewardScoringResponse: ...
|
||
|
||
|
||
class FilteringFunction(Enum):
    """The type of filtering function."""

    # Each member's value is the same string as its name. Declaration order is
    # kept as-is because Enum iteration follows it.
    none = "none"
    random = "random"
    top_k = "top_k"
    top_p = "top_p"
    top_k_top_p = "top_k_top_p"
    sigmoid = "sigmoid"
|
||
|
||
|
||
@json_schema_type
@dataclass
class SyntheticDataGenerationRequest:
    """Request to generate synthetic data. A small batch of prompts and a filtering function"""

    prompts: List[Message]
    # Defaults to the `none` member of FilteringFunction.
    filtering_function: FilteringFunction = FilteringFunction.none
    # NOTE(review): this field is typed as the RewardScoring *Protocol* (an
    # endpoint interface), not a data type -- confirm that is intended.
    reward_scoring: Optional[RewardScoring] = None
|
||
|
||
|
||
@json_schema_type
@dataclass
class SyntheticDataGenerationResponse:
    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

    # Generated data, in the same shape that reward scoring returns.
    synthetic_data: List[KScoredPromptGenerations]
    # Free-form statistics mapping; None when not supplied.
    statistics: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
class SyntheticDataGeneration(Protocol):
    # Interface for the synthetic-data route; the method body is a stub.

    # The return annotation was spelled `Union[SyntheticDataGenerationResponse]`;
    # a single-member Union evaluates to that member, so it is written directly.
    @webmethod(route="/synthetic_data_generation/generate")
    def post_generate(
        self, request: SyntheticDataGenerationRequest
    ) -> SyntheticDataGenerationResponse: ...
|
||
|
||
|
||
@json_schema_type
@dataclass
class CreateDatasetRequest:
    """Request to create a dataset."""

    # A uuid string plus the dataset itself.
    uuid: str
    dataset: Dataset
|
||
|
||
|
||
class Datasets(Protocol):
    # Interface for the /datasets/* routes; method bodies are stubs.

    @webmethod(route="/datasets/create")
    def create_dataset(self, request: CreateDatasetRequest) -> None: ...

    # NOTE(review): get/delete take `dataset_id`, while CreateDatasetRequest
    # names its identifier `uuid` -- confirm they refer to the same value.
    @webmethod(route="/datasets/get")
    def get_dataset(self, dataset_id: str) -> Dataset: ...

    @webmethod(route="/datasets/delete")
    def delete_dataset(self, dataset_id: str) -> None: ...
|
||
|
||
|
||
@json_schema_type
@dataclass
class PostTrainingSFTRequest:
    """Request to finetune a model."""

    # Every field is required; none has a default.
    job_uuid: str

    model: PretrainedModel
    dataset: Dataset
    validation_dataset: Dataset

    # The algorithm plus one of the three matching configuration types.
    algorithm: FinetuningAlgorithm
    algorithm_config: Union[
        LoraFinetuningConfig,
        QLoraFinetuningConfig,
        DoraFinetuningConfig,
    ]

    optimizer_config: OptimizerConfig
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: Dict[str, Any]
    logger_config: Dict[str, Any]
|
||
|
||
|
||
@json_schema_type
@dataclass
class PostTrainingRLHFRequest:
    """Request to finetune a model."""

    # Every field is required; none has a default.
    job_uuid: str

    # Unlike the SFT request, the starting model is given as a URL.
    finetuned_model: URL

    dataset: Dataset
    validation_dataset: Dataset

    algorithm: RLHFAlgorithm
    # A single-member Union evaluates to that member; it is kept in Union form
    # as in the original, which leaves room for further config types.
    algorithm_config: Union[DPOAlignmentConfig]

    optimizer_config: OptimizerConfig
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: Dict[str, Any]
    logger_config: Dict[str, Any]
|
||
|
||
|
||
@json_schema_type
@dataclass
class PostTrainingJobStatusResponse:
    """Status of a finetuning job."""

    # Required: which job, and its current status.
    job_uuid: str
    status: PostTrainingJobStatus

    # Timestamps default to None.
    scheduled_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Free-form mapping; None when not supplied.
    resources_allocated: Optional[Dict[str, Any]] = None

    # Each instance starts with its own empty list.
    checkpoints: List[Checkpoint] = field(default_factory=list)
|
||
|
||
|
||
@json_schema_type
@dataclass
class PostTrainingJobArtifactsResponse:
    """Artifacts of a finetuning job."""

    job_uuid: str
    # Each instance starts with its own empty list.
    checkpoints: List[Checkpoint] = field(default_factory=list)

    # TODO(ashwin): metrics, evals
|
||
|
||
|
||
class PostTraining(Protocol):
    # Interface for the /post_training/* routes; method bodies are stubs.

    # --- job submission (both routes keep their trailing slash) -----------

    @webmethod(route="/post_training/supervised_fine_tune/")
    def post_supervised_fine_tune(self, request: PostTrainingSFTRequest) -> None: ...

    @webmethod(route="/post_training/preference_optimize/")
    def post_preference_optimize(self, request: PostTrainingRLHFRequest) -> None: ...

    # --- per-job operations, all keyed by job_uuid ------------------------

    # sends SSE stream of logs
    @webmethod(route="/post_training/job/logs")
    def get_training_log_stream(
        self, job_uuid: str
    ) -> PostTrainingJobLogStream: ...

    @webmethod(route="/post_training/job/status")
    def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...

    @webmethod(route="/post_training/job/cancel")
    def cancel_training_job(
        self, job_uuid: str
    ) -> None: ...

    @webmethod(route="/post_training/job/artifacts")
    def get_training_job_artifacts(
        self,
        job_uuid: str,
    ) -> PostTrainingJobArtifactsResponse: ...
|
||
|
||
|
||
# Aggregate of every endpoint Protocol in this file; this is the class handed
# to Specification in the __main__ block. Base order is preserved as written.
class LlamaStackEndpoints(
    Inference,
    AgenticSystem,
    RewardScoring,
    SyntheticDataGeneration,
    Datasets,
    PostTraining,
    MemoryBanks,
):
    ...
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build the specification from LlamaStackEndpoints and
    # write it to openapi.yaml and openapi.html in the current directory.
    print("Converting the spec to YAML (openapi.yaml) and HTML (openapi.html)")
    # NOTE(review): leading whitespace on the continuation lines of the
    # description literal could not be recovered from the reviewed copy;
    # confirm it against the canonical file.
    spec = Specification(
        LlamaStackEndpoints,
        Options(
            server=Server(url="http://any-hosted-llama-stack.com"),
            info=Info(
                title="[DRAFT] Llama Stack Specification",
                version="0.0.1",
                description="""Meta has built out a fairly sophisticated platform internally to post train, evaluate, and
serve Llama models to support Meta’s products. Given the newer capabilities of the llama models,
the model development and model serving capabilities of the platform need to be enhanced in
specific ways in order to best leverage the models. For example, the inference platform needs
to support code execution to take advantage of the built-in knowledge of tools of the model.
The largest models are of high enough quality to be used to generate synthetic data or be used
as reward models. There are specific fine tuning and quantization techniques that we have found
result in the best performing Llama models. We would like to share ways in which an LLM Ops
toolchain can be designed by leveraging our learnings in getting Llama models to power Meta’s products.
<br>
In addition, the Llama 3 models Meta will release in July should not just be seen as a model, but
really as a system starting the transition towards an entity capable of performing "agentic" tasks
which require the ability to act as the central planner and break a task down and perform multi-step
reasoning and call tools for specific operations. In addition, there needs to be general model-level
safety checks as well as task-specific safety checks that are performed at a system level.
<br>
We are defining the Llama Stack as a set of APIs and standards by synthesizing our learnings while
working with Llama models. The APIs are divided into the llama-toolchain-api and the llama-agentic-system-api.
These APIs provide a coherent way for model developers to fine tune and serve Llama models, and agentic app
developers to leverage all the capabilities of the Llama models seamlessly. We would like to work with the
ecosystem to enhance and simplify the API. In addition, we will be releasing a plug-in architecture to allow
creating distributions of the llama stack with different implementations.
<br>
This is the specification of the llama stack that provides
a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models. The specification is still in draft and subject to change.""",
            ),
        ),
    )
    with open("openapi.yaml", "w", encoding="utf-8") as fp:
        yaml.dump(spec.get_json(), fp, allow_unicode=True)

    # The description above contains non-ASCII characters (e.g. the curly
    # apostrophe in "Meta’s"). Without an explicit encoding, open() falls back
    # to the locale default, which can fail or mis-encode on non-UTF-8
    # platforms; use UTF-8 to match the YAML output.
    with open("openapi.html", "w", encoding="utf-8") as fp:
        spec.write_html(fp, pretty_print=True)
|