From 6fb69efbe561df793e12cfeec0ff858df1161d10 Mon Sep 17 00:00:00 2001
From: Raghotham Murthy
Date: Wed, 10 Jul 2024 23:25:23 -0700
Subject: [PATCH] Added batch inference

---
 source/api_definitions.py | 40 +++++++++++++++++++++++++++++++++++++++
 source/model_types.py     |  7 +++++++
 source/openapi.html       | 12 ++++++------
 source/openapi.yaml       |  8 ++++----
 4 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/source/api_definitions.py b/source/api_definitions.py
index 37fd663b5..87976a0ea 100644
--- a/source/api_definitions.py
+++ b/source/api_definitions.py
@@ -27,6 +27,7 @@ from finetuning_types import (
 from model_types import (
     BuiltinTool,
     Content,
+    Dialog,
     InstructModel,
     Message,
     PretrainedModel,
@@ -130,6 +131,45 @@ class Inference(Protocol):
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
 
 
+@json_schema_type
+@dataclass
+class BatchCompletionRequest:
+    content_batch: List[Content]
+    model: PretrainedModel
+    sampling_params: SamplingParams = field(default_factory=SamplingParams)
+    max_tokens: int = 0
+    logprobs: bool = False
+
+
+@json_schema_type
+@dataclass
+class BatchChatCompletionRequest:
+    model: InstructModel
+    batch_messages: List[Dialog]
+    sampling_params: SamplingParams = field(default_factory=SamplingParams)
+
+    # zero-shot tool definitions as input to the model
+    available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
+        default_factory=list
+    )
+
+    max_tokens: int = 0
+    logprobs: bool = False
+
+
+class BatchInference(Protocol):
+    """Batch inference calls"""
+    def post_batch_completion(
+        self,
+        request: BatchCompletionRequest,
+    ) -> List[CompletionResponse]: ...
+
+    def post_batch_chat_completion(
+        self,
+        request: BatchChatCompletionRequest,
+    ) -> List[ChatCompletionResponse]: ...
+
+
 @dataclass
 class AgenticSystemCreateRequest:
     instructions: str
diff --git a/source/model_types.py b/source/model_types.py
index f695938f3..9e6e3dc4b 100644
--- a/source/model_types.py
+++ b/source/model_types.py
@@ -121,6 +121,13 @@ class Message:
     tool_responses: List[ToolResponse] = field(default_factory=list)
 
 
+@json_schema_type
+@dataclass
+class Dialog:
+    message: Message
+    message_history: List[Message] = field(default_factory=list)
+
+
 @dataclass
 class SamplingParams:
     temperature: float = 0.0
diff --git a/source/openapi.html b/source/openapi.html
index f44cae180..ff00b3a0c 100644
--- a/source/openapi.html
+++ b/source/openapi.html
@@ -2406,22 +2406,22 @@
         ],
         "tags": [
             {
-                "name": "RewardScoring"
-            },
-            {
-                "name": "Inference"
+                "name": "SyntheticDataGeneration"
             },
             {
                 "name": "Datasets"
             },
             {
-                "name": "SyntheticDataGeneration"
+                "name": "AgenticSystem"
+            },
+            {
+                "name": "Inference"
             },
             {
                 "name": "Finetuning"
             },
             {
-                "name": "AgenticSystem"
+                "name": "RewardScoring"
             },
             {
                 "name": "ShieldConfig",
diff --git a/source/openapi.yaml b/source/openapi.yaml
index 0cdd6af14..36978ac42 100644
--- a/source/openapi.yaml
+++ b/source/openapi.yaml
@@ -1469,12 +1469,12 @@ security:
 servers:
 - url: http://llama.meta.com
 tags:
-- name: RewardScoring
-- name: Inference
-- name: Datasets
 - name: SyntheticDataGeneration
-- name: Finetuning
+- name: Datasets
 - name: AgenticSystem
+- name: Inference
+- name: Finetuning
+- name: RewardScoring
 - description:
   name: ShieldConfig
 - description: