guided decoding initial draft

This commit is contained in:
Ashwin Bharambe 2024-10-21 18:44:19 -07:00
parent 1d241bf3fe
commit 6d26bbdce3
4 changed files with 133 additions and 22 deletions

View file

@ -74,11 +74,28 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: Optional[StopReason] = None
class JsonResponseFormat(BaseModel):
type: Literal["json"] = "json"
schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
type: Literal["grammar"] = "grammar"
bnf: Dict[str, Any]
ResponseFormat = Annotated[
Union[JsonResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
@json_schema_type
class CompletionRequest(BaseModel):
model: str
content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@ -107,6 +124,7 @@ class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
logprobs: Optional[LogProbConfig] = None
@ -129,6 +147,7 @@ class ChatCompletionRequest(BaseModel):
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@ -188,6 +207,7 @@ class Inference(Protocol):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@ -204,6 +224,7 @@ class Inference(Protocol):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...