more definitions

This commit is contained in:
Ashwin Bharambe 2024-07-08 16:35:28 -07:00
parent 722d20c6de
commit 6e4586ba7a
3 changed files with 775 additions and 178 deletions

View file

@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import Any, Dict, List, Optional, Protocol, Set, Union
import yaml
@ -45,16 +45,6 @@ class Role(Enum):
tool = "tool"
class StopReason(Enum):
"""
Stop reasons are used to indicate why the model stopped generating text.
"""
not_stopped = "not_stopped"
finished_ok = "finished_ok"
max_tokens = "max_tokens"
@dataclass
class ToolCall:
"""
@ -77,6 +67,28 @@ class ToolDefinition:
parameters: Dict[str, Any]
# TODO: we need to document the parameters for the tool calls
class BuiltinTool(Enum):
"""
Builtin tools are tools the model is natively aware of and was potentially fine-tuned with.
"""
web_search = "web_search"
math = "math"
image_gen = "image_gen"
code_interpreter = "code_interpreter"
class StopReason(Enum):
"""
Stop reasons are used to indicate why the model stopped generating text.
"""
not_stopped = "not_stopped"
finished_ok = "finished_ok"
max_tokens = "max_tokens"
@json_schema_type
@dataclass
class Message:
@ -85,9 +97,6 @@ class Message:
# input to the model or output from the model
content: Content
# zero-shot tool definitions as input to the model
tool_definitions: List[ToolDefinition] = field(default_factory=list)
# output from the model
tool_calls: List[ToolCall] = field(default_factory=list)
@ -95,45 +104,6 @@ class Message:
tool_responses: List[ToolResponse] = field(default_factory=list)
@json_schema_type
@dataclass
class CompletionResponse:
"""Normal completion response."""
content: Content
stop_reason: StopReason
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedCompletionResponse:
"""streamed completion response."""
text_delta: str
stop_reason: StopReason
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class ChatCompletionResponse:
"""Normal chat completion response."""
content: Content
stop_reason: StopReason
tool_calls: List[ToolCall] = field(default_factory=list)
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedChatCompletionResponse:
"""Streamed chat completion response."""
text_delta: str
stop_reason: StopReason
tool_call: Optional[ToolCall] = None
@dataclass
class SamplingParams:
temperature: float = 0.0
@ -165,16 +135,69 @@ class CompletionRequest:
@json_schema_type
@dataclass
class ChatCompletionRequest:
class CompletionResponse:
"""Normal completion response."""
content: Content
stop_reason: Optional[StopReason] = None
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedCompletionResponse:
"""streamed completion response."""
text_delta: str
stop_reason: Optional[StopReason] = None
logprobs: Optional[Dict[str, Any]] = None
@dataclass
class ChatCompletionRequestCommon:
message: Message
message_history: List[Message] = None
model: InstructModel = InstructModel.llama3_8b_chat
sampling_params: SamplingParams = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
default_factory=list
)
@json_schema_type
@dataclass
class ChatCompletionRequest(ChatCompletionRequestCommon):
max_tokens: int = 0
stream: bool = False
logprobs: bool = False
@json_schema_type
@dataclass
class ChatCompletionResponse:
"""Normal chat completion response."""
content: Content
# note: multiple tool calls can be generated in a single response
tool_calls: List[ToolCall] = field(default_factory=list)
stop_reason: Optional[StopReason] = None
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedChatCompletionResponse:
"""Streamed chat completion response."""
text_delta: str
stop_reason: Optional[StopReason] = None
tool_call: Optional[ToolCall] = None
class Inference(Protocol):
def post_completion(
@ -188,19 +211,41 @@ class Inference(Protocol):
) -> Union[ChatCompletionResponse, StreamedChatCompletionResponse]: ...
@json_schema_type
@dataclass
class AgenticSystemExecuteRequest(ChatCompletionRequestCommon):
executable_tools: Set[str] = field(default_factory=set)
stream: bool = False
@json_schema_type
@dataclass
class AgenticSystemExecuteRequest:
message: Message
message_history: List[Message] = None
model: InstructModel = InstructModel.llama3_8b_chat
sampling_params: SamplingParams = SamplingParams()
class AgenticSystemExecuteResponse:
"""Normal chat completion response."""
content: Content
stop_reason: StopReason
tool_calls: List[ToolCall] = field(default_factory=list)
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedAgenticSystemExecuteResponse:
"""Streamed chat completion response."""
text_delta: str
stop_reason: StopReason
tool_call: Optional[ToolCall] = None
class AgenticSystem(Protocol):
@webmethod(route="/agentic/system/execute")
def create_agentic_system_execute(self,) -> str: ...
def create_agentic_system_execute(
self,
request: AgenticSystemExecuteRequest,
) -> Union[AgenticSystemExecuteResponse, StreamedAgenticSystemExecuteResponse]: ...
class Endpoint(Inference, AgenticSystem): ...