Added batch inference

This commit is contained in:
Raghotham Murthy 2024-07-10 23:25:23 -07:00
parent 22d6093258
commit 6fb69efbe5
4 changed files with 57 additions and 10 deletions

View file

@ -27,6 +27,7 @@ from finetuning_types import (
from model_types import (
BuiltinTool,
Content,
Dialog,
InstructModel,
Message,
PretrainedModel,
@ -130,6 +131,45 @@ class Inference(Protocol):
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@json_schema_type
@dataclass
class BatchCompletionRequest:
content_batch: List[Content]
model: PretrainedModel
sampling_params: SamplingParams = SamplingParams()
max_tokens: int = 0
logprobs: bool = False
@json_schema_type
@dataclass
class BatchChatCompletionRequest:
model: InstructModel
batch_messages: List[Dialog]
sampling_params: SamplingParams = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
default_factory=list
)
max_tokens: int = 0
logprobs: bool = False
class BatchInference(Protocol):
"""Batch inference calls"""
def post_batch_completion(
self,
request: BatchCompletionRequest,
) -> List[CompletionResponse]: ...
def post_batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> List[ChatCompletionResponse]: ...
@dataclass
class AgenticSystemCreateRequest:
instructions: str

View file

@ -121,6 +121,13 @@ class Message:
tool_responses: List[ToolResponse] = field(default_factory=list)
@json_schema_type
@dataclass
class Dialog:
message: Message
message_history: List[Message] = None
@dataclass
class SamplingParams:
temperature: float = 0.0

View file

@ -2406,22 +2406,22 @@
],
"tags": [
{
"name": "RewardScoring"
},
{
"name": "Inference"
"name": "SyntheticDataGeneration"
},
{
"name": "Datasets"
},
{
"name": "SyntheticDataGeneration"
"name": "AgenticSystem"
},
{
"name": "Inference"
},
{
"name": "Finetuning"
},
{
"name": "AgenticSystem"
"name": "RewardScoring"
},
{
"name": "ShieldConfig",

View file

@ -1469,12 +1469,12 @@ security:
servers:
- url: http://llama.meta.com
tags:
- name: RewardScoring
- name: Inference
- name: Datasets
- name: SyntheticDataGeneration
- name: Finetuning
- name: Datasets
- name: AgenticSystem
- name: Inference
- name: Finetuning
- name: RewardScoring
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
name: ShieldConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"