Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-01 20:18:50 +00:00

Commit 6e4586ba7a ("more definitions"), parent 722d20c6de
3 changed files with 775 additions and 178 deletions
source/defn.py (165 changes)
@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, List, Optional, Protocol, Set, Union
 
 import yaml
 
@@ -45,16 +45,6 @@ class Role(Enum):
     tool = "tool"
 
 
-class StopReason(Enum):
-    """
-    Stop reasons are used to indicate why the model stopped generating text.
-    """
-
-    not_stopped = "not_stopped"
-    finished_ok = "finished_ok"
-    max_tokens = "max_tokens"
-
-
 @dataclass
 class ToolCall:
     """
@@ -77,6 +67,28 @@ class ToolDefinition:
     parameters: Dict[str, Any]
 
 
+# TODO: we need to document the parameters for the tool calls
+class BuiltinTool(Enum):
+    """
+    Builtin tools are tools the model is natively aware of and was potentially fine-tuned with.
+    """
+
+    web_search = "web_search"
+    math = "math"
+    image_gen = "image_gen"
+    code_interpreter = "code_interpreter"
+
+
+class StopReason(Enum):
+    """
+    Stop reasons are used to indicate why the model stopped generating text.
+    """
+
+    not_stopped = "not_stopped"
+    finished_ok = "finished_ok"
+    max_tokens = "max_tokens"
+
+
 @json_schema_type
 @dataclass
 class Message:
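An illustrative aside, not part of the diff: both enums added in this hunk are string-valued, so their members serialize to plain strings. A minimal sketch, assuming source/defn.py is importable as defn:

import defn

# String-valued Enum members round-trip through .value / the Enum constructor.
tool = defn.BuiltinTool.code_interpreter
assert tool.value == "code_interpreter"
assert defn.BuiltinTool("web_search") is defn.BuiltinTool.web_search

# StopReason keeps the same pattern after its move below BuiltinTool.
assert defn.StopReason.max_tokens.value == "max_tokens"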
@@ -85,9 +97,6 @@ class Message:
     # input to the model or output from the model
     content: Content
 
-    # zero-shot tool definitions as input to the model
-    tool_definitions: List[ToolDefinition] = field(default_factory=list)
-
     # output from the model
     tool_calls: List[ToolCall] = field(default_factory=list)
 
@@ -95,45 +104,6 @@ class Message:
     tool_responses: List[ToolResponse] = field(default_factory=list)
 
 
-@json_schema_type
-@dataclass
-class CompletionResponse:
-    """Normal completion response."""
-    content: Content
-    stop_reason: StopReason
-    logprobs: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-@dataclass
-class StreamedCompletionResponse:
-    """streamed completion response."""
-    text_delta: str
-    stop_reason: StopReason
-    logprobs: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-@dataclass
-class ChatCompletionResponse:
-    """Normal chat completion response."""
-
-    content: Content
-    stop_reason: StopReason
-    tool_calls: List[ToolCall] = field(default_factory=list)
-    logprobs: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-@dataclass
-class StreamedChatCompletionResponse:
-    """Streamed chat completion response."""
-
-    text_delta: str
-    stop_reason: StopReason
-    tool_call: Optional[ToolCall] = None
-
-
 @dataclass
 class SamplingParams:
     temperature: float = 0.0
@@ -165,16 +135,69 @@ class CompletionRequest:
 
 @json_schema_type
 @dataclass
-class ChatCompletionRequest:
+class CompletionResponse:
+    """Normal completion response."""
+
+    content: Content
+    stop_reason: Optional[StopReason] = None
+    logprobs: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+@dataclass
+class StreamedCompletionResponse:
+    """streamed completion response."""
+
+    text_delta: str
+    stop_reason: Optional[StopReason] = None
+    logprobs: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class ChatCompletionRequestCommon:
     message: Message
     message_history: List[Message] = None
     model: InstructModel = InstructModel.llama3_8b_chat
     sampling_params: SamplingParams = SamplingParams()
 
     # zero-shot tool definitions as input to the model
     available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
         default_factory=list
     )
+
+
+@json_schema_type
+@dataclass
+class ChatCompletionRequest(ChatCompletionRequestCommon):
     max_tokens: int = 0
     stream: bool = False
     logprobs: bool = False
+
+
+@json_schema_type
+@dataclass
+class ChatCompletionResponse:
+    """Normal chat completion response."""
+
+    content: Content
+
+    # note: multiple tool calls can be generated in a single response
+    tool_calls: List[ToolCall] = field(default_factory=list)
+
+    stop_reason: Optional[StopReason] = None
+    logprobs: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+@dataclass
+class StreamedChatCompletionResponse:
+    """Streamed chat completion response."""
+
+    text_delta: str
+    stop_reason: Optional[StopReason] = None
+    tool_call: Optional[ToolCall] = None
 
 
 class Inference(Protocol):
 
     def post_completion(
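For orientation, another sketch that is not part of the diff: the new ChatCompletionRequestCommon base carries the shared fields, so ChatCompletionRequest only adds its own knobs. Constructing one might look as follows; the Message constructor and the Role.user member are assumptions, since Message's full definition and most Role members sit outside this hunk:

import defn

# Hypothetical message; Message's exact fields are defined outside this hunk.
msg = defn.Message(role=defn.Role.user, content="What is 2 + 2?")

request = defn.ChatCompletionRequest(
    message=msg,
    available_tools=[defn.BuiltinTool.math],  # List[Union[BuiltinTool, ToolDefinition]]
    max_tokens=128,
)

# Defaults are inherited from ChatCompletionRequestCommon.
assert request.model is defn.InstructModel.llama3_8b_chat
assert request.sampling_params.temperature == 0.0
assert request.stream is False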
@@ -188,19 +211,41 @@ class Inference(Protocol):
     ) -> Union[ChatCompletionResponse, StreamedChatCompletionResponse]: ...
 
 
 @json_schema_type
 @dataclass
-class AgenticSystemExecuteRequest:
-    message: Message
-    message_history: List[Message] = None
-    model: InstructModel = InstructModel.llama3_8b_chat
-    sampling_params: SamplingParams = SamplingParams()
+class AgenticSystemExecuteRequest(ChatCompletionRequestCommon):
+    executable_tools: Set[str] = field(default_factory=set)
+    stream: bool = False
+
+
+@json_schema_type
+@dataclass
+class AgenticSystemExecuteResponse:
+    """Normal chat completion response."""
+
+    content: Content
+    stop_reason: StopReason
+    tool_calls: List[ToolCall] = field(default_factory=list)
+    logprobs: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+@dataclass
+class StreamedAgenticSystemExecuteResponse:
+    """Streamed chat completion response."""
+
+    text_delta: str
+    stop_reason: StopReason
+    tool_call: Optional[ToolCall] = None
 
 
 class AgenticSystem(Protocol):
 
     @webmethod(route="/agentic/system/execute")
-    def create_agentic_system_execute(self,) -> str: ...
+    def create_agentic_system_execute(
+        self,
+        request: AgenticSystemExecuteRequest,
+    ) -> Union[AgenticSystemExecuteResponse, StreamedAgenticSystemExecuteResponse]: ...
 
 
 class Endpoint(Inference, AgenticSystem): ...
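Finally, a sketch, again not part of the diff, of what the new request/response pair enables: AgenticSystem is a typing.Protocol, so any class with a matching method satisfies it structurally, no inheritance required. A stub implementation using only fields visible in this hunk:

import defn

class EchoAgenticSystem:
    # Satisfies the AgenticSystem Protocol structurally (no subclassing needed).
    def create_agentic_system_execute(
        self,
        request: defn.AgenticSystemExecuteRequest,
    ) -> defn.AgenticSystemExecuteResponse:
        # Trivial "execution": echo the incoming message content back.
        return defn.AgenticSystemExecuteResponse(
            content=request.message.content,
            stop_reason=defn.StopReason.finished_ok,
        )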