llama-stack-mirror/source/defn.py
2024-07-08 16:35:28 -07:00

271 lines
5.9 KiB
Python

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Set, Union
import yaml
from pyopenapi import Info, Options, Server, Specification, webmethod
from strong_typing.schema import json_schema_type
@json_schema_type(
    schema={
        "type": "string",
        "format": "uri",
        "pattern": "^(https?://|file://|data:)",
    }
)
@dataclass
class URL:
    """
    A URL held as a plain string inside a dataclass.

    The class is registered with an explicit JSON schema that describes it
    as a string in "uri" format restricted to http(s), file and data URLs.
    """

    url: str

    def __str__(self) -> str:
        # The string form is the wrapped URL itself.
        return self.url
@json_schema_type
@dataclass
class Attachment:
    """
    Reference to an external resource such as an image, a video or an audio
    clip, given by its URL together with its MIME type.
    """

    url: URL
    mime_type: str
# Content is either a plain string, a single attachment, or a list that may
# mix strings and attachments.
Content = Union[str, Attachment, List[Union[str, Attachment]]]
class Role(Enum):
    """Role recorded on a message: system, user, assistant or tool."""

    system = "system"
    user = "user"
    assistant = "assistant"
    tool = "tool"
@dataclass
class ToolCall:
    """
    Request addressed to a tool: the tool's name plus the arguments for it.
    """

    tool_name: str
    arguments: Dict[str, Any]
@dataclass
class ToolResponse:
    """Response attributed to a tool: the tool's name and a response string."""

    tool_name: str
    response: str
@dataclass
class ToolDefinition:
    """Definition of a tool: its name and a dictionary of parameters."""

    tool_name: str
    # NOTE(review): the layout of this dictionary is not specified anywhere
    # in this file.
    parameters: Dict[str, Any]
# TODO: we need to document the parameters for the tool calls
class BuiltinTool(Enum):
    """
    Tools the model is natively aware of and was potentially fine-tuned with.
    """

    web_search = "web_search"
    math = "math"
    image_gen = "image_gen"
    code_interpreter = "code_interpreter"
class StopReason(Enum):
    """
    Reason why the model stopped generating text.
    """

    not_stopped = "not_stopped"
    finished_ok = "finished_ok"
    max_tokens = "max_tokens"
@json_schema_type
@dataclass
class Message:
    """
    A single message: its role, its content, and any tool calls or tool
    responses that accompany it.
    """

    role: Role

    # input to the model or output from the model
    content: Content

    # output from the model
    tool_calls: List[ToolCall] = field(default_factory=list)

    # input to the model
    tool_responses: List[ToolResponse] = field(default_factory=list)
@dataclass
class SamplingParams:
    """
    Sampling settings for a request.

    With no arguments the strategy is "greedy" with temperature 0.0,
    top_p 0.95 and top_k 0.
    """

    temperature: float = 0.0
    strategy: str = "greedy"
    top_p: float = 0.95
    top_k: int = 0
class PretrainedModel(Enum):
    """Identifiers of the pretrained models accepted by completion requests."""

    llama3_8b = "llama3_8b"
    llama3_70b = "llama3_70b"
class InstructModel(Enum):
    """Identifiers of the chat models accepted by chat-style requests."""

    llama3_8b_chat = "llama3_8b_chat"
    llama3_70b_chat = "llama3_70b_chat"
@json_schema_type
@dataclass
class CompletionRequest:
    """
    Request for a plain (non-chat) completion.

    Only ``content`` is required; every other field has a default.
    """

    content: Content
    model: PretrainedModel = PretrainedModel.llama3_8b
    # Built per instance through default_factory. A literal SamplingParams()
    # default is an unhashable dataclass instance: Python 3.11+ refuses it
    # with ValueError when the class is created, and older versions would
    # share that one object between all requests.
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    max_tokens: int = 0
    stream: bool = False
    logprobs: bool = False
@json_schema_type
@dataclass
class CompletionResponse:
    """Non-streamed completion response."""

    content: Content

    # Both fields are optional and default to None.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedCompletionResponse:
    """Streamed completion response carrying a text delta."""

    text_delta: str

    # Both fields are optional and default to None.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[Dict[str, Any]] = None
@dataclass
class ChatCompletionRequestCommon:
    """
    Fields shared by the chat-style request classes that derive from this
    one (ChatCompletionRequest and AgenticSystemExecuteRequest).

    Only ``message`` is required; every other field has a default.
    """

    message: Message
    # Annotated Optional because the default is None rather than a list;
    # the runtime default itself is unchanged.
    message_history: Optional[List[Message]] = None
    model: InstructModel = InstructModel.llama3_8b_chat
    # Built per instance through default_factory. A literal SamplingParams()
    # default is an unhashable dataclass instance: Python 3.11+ refuses it
    # with ValueError when the class is created, and older versions would
    # share that one object between all requests.
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    # zero-shot tool definitions as input to the model
    available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
        default_factory=list
    )
@json_schema_type
@dataclass
class ChatCompletionRequest(ChatCompletionRequestCommon):
    """
    Chat completion request: the common chat request fields plus
    ``max_tokens``, ``stream`` and ``logprobs``.
    """

    max_tokens: int = 0
    stream: bool = False
    logprobs: bool = False
@json_schema_type
@dataclass
class ChatCompletionResponse:
    """Non-streamed chat completion response."""

    content: Content

    # note: multiple tool calls can be generated in a single response
    tool_calls: List[ToolCall] = field(default_factory=list)

    # Both fields are optional and default to None.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedChatCompletionResponse:
    """Streamed chat completion response carrying a text delta."""

    text_delta: str

    # Both fields are optional and default to None; unlike the non-streamed
    # response, this one holds at most a single tool call.
    stop_reason: Optional[StopReason] = None
    tool_call: Optional[ToolCall] = None
class Inference(Protocol):
    """Protocol declaring the completion and chat completion operations."""

    # Takes a CompletionRequest; the result is either the normal or the
    # streamed completion response type.
    def post_completion(
        self, request: CompletionRequest
    ) -> Union[CompletionResponse, StreamedCompletionResponse]:
        ...

    # Takes a ChatCompletionRequest; the result is either the normal or the
    # streamed chat completion response type.
    def post_chat_completion(
        self, request: ChatCompletionRequest
    ) -> Union[ChatCompletionResponse, StreamedChatCompletionResponse]:
        ...
@json_schema_type
@dataclass
class AgenticSystemExecuteRequest(ChatCompletionRequestCommon):
    """
    Agentic system execute request: the common chat request fields plus a
    set of executable tool names and a ``stream`` flag.
    """

    executable_tools: Set[str] = field(default_factory=set)
    stream: bool = False
@json_schema_type
@dataclass
class AgenticSystemExecuteResponse:
    """Non-streamed agentic system execute response."""

    content: Content
    # Required here, unlike in the completion response classes.
    stop_reason: StopReason

    tool_calls: List[ToolCall] = field(default_factory=list)
    logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedAgenticSystemExecuteResponse:
    """Streamed agentic system execute response carrying a text delta."""

    text_delta: str
    # Required here, unlike in the streamed completion response classes.
    stop_reason: StopReason

    tool_call: Optional[ToolCall] = None
class AgenticSystem(Protocol):
    """Protocol declaring the agentic system execute operation."""

    # Declared with the explicit route "/agentic/system/execute"; the result
    # is either the normal or the streamed execute response type.
    @webmethod(route="/agentic/system/execute")
    def create_agentic_system_execute(
        self, request: AgenticSystemExecuteRequest
    ) -> Union[AgenticSystemExecuteResponse, StreamedAgenticSystemExecuteResponse]:
        ...
class Endpoint(Inference, AgenticSystem):
    """Combination of the Inference and AgenticSystem operations."""
if __name__ == "__main__":
    # Script entry point: build the specification for Endpoint and write it
    # to openapi.yaml and openapi.html in the current working directory.
    print("Converting the spec to YAML (openapi.yaml) and HTML (openapi.html)")
    spec = Specification(
        Endpoint,
        Options(
            server=Server(url="http://llama.meta.com"),
            info=Info(
                title="Llama Stack specification",
                version="0.1",
                description="This is the llama stack",
            ),
        ),
    )

    with open("openapi.yaml", "w", encoding="utf-8") as fp:
        yaml.dump(spec.get_json(), fp, allow_unicode=True)

    # Explicit UTF-8, matching the YAML output above. Without it open() uses
    # the locale's preferred encoding, which can fail on or mis-encode
    # non-ASCII text in the specification.
    with open("openapi.html", "w", encoding="utf-8") as fp:
        spec.write_html(fp, pretty_print=True)