Ashwin Bharambe 2024-07-10 23:33:57 -07:00
parent ee86f2c75f
commit 7cade3acc3
3 changed files with 721 additions and 271 deletions

@@ -80,15 +80,12 @@ class CompletionResponseStreamChunk:
 @json_schema_type
 @dataclass
 class ChatCompletionRequest:
-    message: Message
     model: InstructModel
-    message_history: List[Message] = None
+    dialog: Dialog
     sampling_params: SamplingParams = SamplingParams()

     # zero-shot tool definitions as input to the model
-    available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
-        default_factory=list
-    )
+    available_tools: List[ToolDefinition] = field(default_factory=list)

     max_tokens: int = 0
     stream: bool = False
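For orientation, here is a minimal sketch of the new request shape from the caller's side. The Message and Dialog stubs below are hypothetical stand-ins for the types this module actually imports; only the move from message/message_history to a single dialog field comes from the diff above.

from dataclasses import dataclass
from typing import List

# Hypothetical stand-ins for the module's real Message and Dialog types.
@dataclass
class Message:
    role: str
    content: str

@dataclass
class Dialog:
    # Assumption: a Dialog carries the whole turn sequence, prompt included.
    messages: List[Message]

# After this change a caller passes one Dialog value instead of a
# `message` plus a separate `message_history` list:
dialog = Dialog(messages=[Message(role="user", content="Hello!")])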
@@ -119,6 +116,30 @@ class ChatCompletionResponseStreamChunk:
     tool_call: Optional[ToolCall] = None


+@json_schema_type
+@dataclass
+class BatchCompletionRequest:
+    model: PretrainedModel
+    content_batch: List[Content]
+    sampling_params: SamplingParams = SamplingParams()
+    max_tokens: int = 0
+    logprobs: bool = False
+
+
+@json_schema_type
+@dataclass
+class BatchChatCompletionRequest:
+    model: InstructModel
+    batch_dialogs: List[Dialog]
+    sampling_params: SamplingParams = SamplingParams()
+
+    # zero-shot tool definitions as input to the model
+    available_tools: List[ToolDefinition] = field(default_factory=list)
+
+    max_tokens: int = 0
+    logprobs: bool = False
+
+
 class Inference(Protocol):

     def post_completion(
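The batch variants mirror the single-request types one-to-one. A hedged sketch of how a provider might drain a BatchChatCompletionRequest, with hypothetical stub types standing in for the module's real ones:

from dataclasses import dataclass
from typing import List

# Hypothetical stubs, just enough to make the sketch runnable.
@dataclass
class Message:
    role: str
    content: str

@dataclass
class Dialog:
    messages: List[Message]

@dataclass
class BatchChatCompletionRequest:
    model: str                      # stand-in for InstructModel
    batch_dialogs: List[Dialog]
    max_tokens: int = 0
    logprobs: bool = False

# One request now carries N dialogs; a naive provider can still serve it
# by running single-dialog inference once per entry.
request = BatchChatCompletionRequest(
    model="example-instruct-model",  # hypothetical identifier
    batch_dialogs=[
        Dialog(messages=[Message(role="user", content="Hello")]),
        Dialog(messages=[Message(role="user", content="Goodbye")]),
    ],
)
for dialog in request.batch_dialogs:
    pass  # run single-dialog chat completion for each entry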
@@ -131,35 +152,6 @@ class Inference(Protocol):
         request: ChatCompletionRequest,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...


-@json_schema_type
-@dataclass
-class BatchCompletionRequest:
-    content_batch: List[Content]
-    model: PretrainedModel
-    sampling_params: SamplingParams = SamplingParams()
-    max_tokens: int = 0
-    logprobs: bool = False
-
-
-@json_schema_type
-@dataclass
-class BatchChatCompletionRequest:
-    model: InstructModel
-    batch_messages: List[Dialog]
-    sampling_params: SamplingParams = SamplingParams()
-
-    # zero-shot tool definitions as input to the model
-    available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
-        default_factory=list
-    )
-
-    max_tokens: int = 0
-    logprobs: bool = False
-
-
 class BatchInference(Protocol):
     """Batch inference calls"""

     def post_batch_completion(
         self,
         request: BatchCompletionRequest,
@@ -302,8 +294,7 @@ class MemoryBanks(Protocol):
 @dataclass
 class KPromptGenerations:
-    prompt: Message
-    message_history: List[Message]
+    dialog: Dialog
     k_generations: List[Message]
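A hedged sketch of the reshaped evaluation record; again the Message and Dialog stubs are hypothetical stand-ins, while the field names match the diff:

from dataclasses import dataclass
from typing import List

# Hypothetical stubs for the module's real types.
@dataclass
class Message:
    role: str
    content: str

@dataclass
class Dialog:
    messages: List[Message]

@dataclass
class KPromptGenerations:
    dialog: Dialog                 # full input conversation, prompt included
    k_generations: List[Message]   # k sampled completions for that dialog

# Example: two sampled generations for one single-turn dialog.
record = KPromptGenerations(
    dialog=Dialog(messages=[Message(role="user", content="Name a color.")]),
    k_generations=[
        Message(role="assistant", content="Blue."),
        Message(role="assistant", content="Green."),
    ],
)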