making the API python based with a converter script

This commit is contained in:
Ashwin Bharambe 2024-07-07 17:01:14 -07:00
parent 1a2b17af7f
commit 722d20c6de
5 changed files with 1565 additions and 0 deletions

226
source/defn.py Normal file
View file

@ -0,0 +1,226 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
import yaml
from pyopenapi import Info, Options, Server, Specification, webmethod
from strong_typing.schema import json_schema_type
@json_schema_type(
schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
)
@dataclass
class URL:
url: str
def __str__(self) -> str:
return self.url
@json_schema_type
@dataclass
class Attachment:
"""
Attachments are used to refer to external resources, such as images, videos, audio, etc.
"""
url: URL
mime_type: str
Content = Union[
str,
Attachment,
List[Union[str, Attachment]],
]
class Role(Enum):
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
class StopReason(Enum):
"""
Stop reasons are used to indicate why the model stopped generating text.
"""
not_stopped = "not_stopped"
finished_ok = "finished_ok"
max_tokens = "max_tokens"
@dataclass
class ToolCall:
"""
A tool call is a request to a tool.
"""
tool_name: str
arguments: Dict[str, Any]
@dataclass
class ToolResponse:
tool_name: str
response: str
@dataclass
class ToolDefinition:
tool_name: str
parameters: Dict[str, Any]
@json_schema_type
@dataclass
class Message:
role: Role
# input to the model or output from the model
content: Content
# zero-shot tool definitions as input to the model
tool_definitions: List[ToolDefinition] = field(default_factory=list)
# output from the model
tool_calls: List[ToolCall] = field(default_factory=list)
# input to the model
tool_responses: List[ToolResponse] = field(default_factory=list)
@json_schema_type
@dataclass
class CompletionResponse:
"""Normal completion response."""
content: Content
stop_reason: StopReason
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedCompletionResponse:
"""streamed completion response."""
text_delta: str
stop_reason: StopReason
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class ChatCompletionResponse:
"""Normal chat completion response."""
content: Content
stop_reason: StopReason
tool_calls: List[ToolCall] = field(default_factory=list)
logprobs: Optional[Dict[str, Any]] = None
@json_schema_type
@dataclass
class StreamedChatCompletionResponse:
"""Streamed chat completion response."""
text_delta: str
stop_reason: StopReason
tool_call: Optional[ToolCall] = None
@dataclass
class SamplingParams:
temperature: float = 0.0
strategy: str = "greedy"
top_p: float = 0.95
top_k: int = 0
class PretrainedModel(Enum):
llama3_8b = "llama3_8b"
llama3_70b = "llama3_70b"
class InstructModel(Enum):
llama3_8b_chat = "llama3_8b_chat"
llama3_70b_chat = "llama3_70b_chat"
@json_schema_type
@dataclass
class CompletionRequest:
content: Content
model: PretrainedModel = PretrainedModel.llama3_8b
sampling_params: SamplingParams = SamplingParams()
max_tokens: int = 0
stream: bool = False
logprobs: bool = False
@json_schema_type
@dataclass
class ChatCompletionRequest:
message: Message
message_history: List[Message] = None
model: InstructModel = InstructModel.llama3_8b_chat
sampling_params: SamplingParams = SamplingParams()
max_tokens: int = 0
stream: bool = False
logprobs: bool = False
class Inference(Protocol):
def post_completion(
self,
request: CompletionRequest,
) -> Union[CompletionResponse, StreamedCompletionResponse]: ...
def post_chat_completion(
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, StreamedChatCompletionResponse]: ...
@json_schema_type
@dataclass
class AgenticSystemExecuteRequest:
message: Message
message_history: List[Message] = None
model: InstructModel = InstructModel.llama3_8b_chat
sampling_params: SamplingParams = SamplingParams()
class AgenticSystem(Protocol):
@webmethod(route="/agentic/system/execute")
def create_agentic_system_execute(self,) -> str: ...
class Endpoint(Inference, AgenticSystem): ...
if __name__ == "__main__":
print("Converting the spec to YAML (openapi.yaml) and HTML (openapi.html)")
spec = Specification(
Endpoint,
Options(
server=Server(url="http://llama.meta.com"),
info=Info(
title="Llama Stack specification",
version="0.1",
description="This is the llama stack",
),
),
)
with open("openapi.yaml", "w", encoding="utf-8") as fp:
yaml.dump(spec.get_json(), fp, allow_unicode=True)
with open("openapi.html", "w") as fp:
spec.write_html(fp, pretty_print=True)