add response format to signature

Ashwin Bharambe 2024-10-21 19:14:52 -07:00 committed by Ashwin Bharambe
parent 6d26bbdce3
commit 40ba22f4c8
15 changed files with 93 additions and 32 deletions


@@ -53,6 +53,7 @@ class InferenceClient(Inference):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -63,6 +64,7 @@ class InferenceClient(Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
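For illustration only (not part of the diff): a caller exercising the new parameter through the client might look like the sketch below. The import paths, the InferenceClient constructor, the endpoint URL, and the model name are assumptions.

# Hedged sketch: pass a JSON response format through the updated client signature.
# Module paths, constructor arguments, and the model name are assumptions.
import asyncio

from llama_stack.apis.inference import JsonResponseFormat, UserMessage  # assumed path
from llama_stack.apis.inference.client import InferenceClient  # assumed path


async def main() -> None:
    client = InferenceClient("http://localhost:5000")  # hypothetical local server
    async for chunk in client.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Describe Michael Jordan as JSON.")],
        response_format=JsonResponseFormat(
            schema={"type": "object", "properties": {"first_name": {"type": "string"}}},
        ),
        stream=True,
    ):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())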


@@ -74,13 +74,18 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None


+class ResponseFormatType(Enum):
+    json = "json"
+    grammar = "grammar"
+
+
 class JsonResponseFormat(BaseModel):
-    type: Literal["json"] = "json"
+    type: Literal[ResponseFormatType.json.value] = ResponseFormatType.json.value
     schema: Dict[str, Any]


 class GrammarResponseFormat(BaseModel):
-    type: Literal["grammar"] = "grammar"
+    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
     bnf: Dict[str, Any]
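The hunks throughout this commit annotate parameters as Optional[ResponseFormat], but the alias itself is not shown in the diff. It is presumably a tagged union over the two models above; the sketch below is a self-contained, hedged reconstruction (the Annotated/discriminator wiring, pydantic v2, and the example AnswerFormat fields are assumptions, and the real file ties the Literal values to ResponseFormatType).

# Hedged sketch: a plausible ResponseFormat union over the two new models.
from typing import Annotated, Any, Dict, Literal, Union

from pydantic import BaseModel, Field


class JsonResponseFormat(BaseModel):
    type: Literal["json"] = "json"
    schema: Dict[str, Any]  # a JSON Schema document


class GrammarResponseFormat(BaseModel):
    type: Literal["grammar"] = "grammar"
    bnf: Dict[str, Any]  # grammar rules, keyed by non-terminal


# The union the new `response_format` parameters presumably accept.
ResponseFormat = Annotated[
    Union[JsonResponseFormat, GrammarResponseFormat],
    Field(discriminator="type"),
]


class AnswerFormat(BaseModel):  # illustrative schema source
    first_name: str
    last_name: str


fmt = JsonResponseFormat(schema=AnswerFormat.model_json_schema())
print(fmt.type, sorted(fmt.schema["properties"]))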


@@ -75,6 +75,7 @@ class InferenceRouter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@@ -102,6 +103,7 @@ class InferenceRouter(Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:


@@ -52,6 +52,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -288,6 +289,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,


@@ -53,6 +53,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -63,6 +64,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,


@@ -56,6 +56,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -66,6 +67,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,


@@ -93,6 +93,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -160,6 +161,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,


@@ -71,6 +71,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -81,6 +82,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,


@@ -59,6 +59,7 @@ class TogetherInferenceAdapter(
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -69,6 +70,7 @@ class TogetherInferenceAdapter(
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,


@@ -80,6 +80,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -90,6 +91,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,


@@ -26,6 +26,7 @@ class MockInferenceAPI:
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,


@@ -80,7 +80,7 @@ class Llama:
     def build(
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-        ]
+        ],
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -184,17 +184,11 @@ class Llama:
         echo: bool = False,
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
+        logits_processor: Optional["LogitsProcessor"] = None,
     ) -> Generator:
-        parser = JsonSchemaParser(AnswerFormat.schema())
-        tokenizer_data = build_token_enforcer_tokenizer_data(
-            self.tokenizer, self.args.vocab_size
-        )
-        token_enforcer = TokenEnforcer(tokenizer_data, parser)
-        logits_processor = LogitsProcessor(token_enforcer)
-
         params = self.model.params

-        if print_input_tokens or True:
+        if print_input_tokens:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
@@ -266,10 +260,10 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

-                # print(f"{logits=}")
-                input_ids = tokens[0, :cur_pos].tolist()
-                # logits = logits_processor.process_logits(input_ids, logits)
-                # print(f"{logits=}")
+            if logits_processor is not None:
+                logits = logits_processor.process_logits(
+                    tokens[0, :cur_pos].tolist(), logits
+                )

             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -342,7 +336,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
-            echo=False,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )

     def chat_completion(
@@ -370,6 +368,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )
@@ -398,6 +401,27 @@ def sample_top_p(probs, p):
     return next_token


+def get_logits_processor(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+    response_format: Optional[ResponseFormat],
+) -> Optional["LogitsProcessor"]:
+    if response_format is None:
+        return None
+
+    if response_format.type != ResponseFormatType.json.value:
+        raise ValueError(f"Unsupported response format type {response_format.type}")
+
+    parser = JsonSchemaParser(response_format.schema)
+    data = TokenEnforcerTokenizerData(
+        _build_regular_tokens_list(tokenizer, vocab_size),
+        tokenizer.decode,
+        tokenizer.stop_tokens,
+    )
+    token_enforcer = TokenEnforcer(data, parser)
+    return LogitsProcessor(token_enforcer)
+
+
 class LogitsProcessor:
     def __init__(self, token_enforcer: TokenEnforcer):
         self.token_enforcer = token_enforcer
@@ -437,13 +461,3 @@ def _build_regular_tokens_list(
         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
     return regular_tokens
-
-
-def build_token_enforcer_tokenizer_data(
-    tokenizer: Tokenizer,
-    vocab_size: int,
-) -> TokenEnforcerTokenizerData:
-    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
-    return TokenEnforcerTokenizerData(
-        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
-    )


@@ -71,6 +71,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -81,6 +82,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -186,6 +188,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@@ -203,6 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )


@@ -97,12 +97,12 @@ class AnswerFormat(BaseModel):
 @pytest.fixture
 def sample_messages():
-    question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
-    question_with_schema = f"{question}{AnswerFormat.schema_json()}"
+    question = "Please give me information about Michael Jordan."
+    # question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
         # UserMessage(content="What's the weather like today?"),
-        UserMessage(content=question_with_schema),
+        UserMessage(content=question),
     ]
@@ -183,10 +183,15 @@ async def test_completion(inference_settings):
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
+    print(AnswerFormat.schema_json())
+    print(AnswerFormat.schema())
     inference_impl = inference_settings["impl"]
     response = await inference_impl.chat_completion(
         messages=sample_messages,
         stream=False,
+        response_format=JsonResponseFormat(
+            schema=AnswerFormat.schema(),
+        ),
         **inference_settings["common_params"],
     )
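A natural follow-up check for this test, sketched here in self-contained form: validate the model's JSON reply against the same pydantic model that produced the schema. The AnswerFormat fields and the hard-coded reply are illustrative assumptions; the real test defines its own AnswerFormat and reads the content from the chat_completion response.

# Hedged sketch: round-trip the structured output through the schema model.
# Field names and the sample reply are assumptions; pydantic v2 is assumed.
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    num_seasons_in_nba: int


# Stand-in for response.completion_message.content from chat_completion.
raw = '{"first_name": "Michael", "last_name": "Jordan", "num_seasons_in_nba": 15}'

answer = AnswerFormat.model_validate_json(raw)
assert answer.last_name == "Jordan"
print(answer)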


@@ -70,11 +70,25 @@ def chat_completion_request_to_messages(
         and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
-        return augment_messages_for_tools_llama_3_1(request)
+        messages = augment_messages_for_tools_llama_3_1(request)
     elif model.model_family == ModelFamily.llama3_2:
-        return augment_messages_for_tools_llama_3_2(request)
+        messages = augment_messages_for_tools_llama_3_2(request)
     else:
-        return request.messages
+        messages = request.messages
+
+    if fmt := request.response_format:
+        if fmt.type == ResponseFormatType.json.value:
+            messages.append(
+                UserMessage(
+                    content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
+                )
+            )
+        elif fmt.type == ResponseFormatType.grammar.value:
+            raise NotImplementedError("Grammar response format not supported yet")
+        else:
+            raise ValueError(f"Unknown response format {fmt.type}")
+
+    return messages


 def augment_messages_for_tools_llama_3_1(
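For providers that cannot constrain decoding directly, the branch above falls back to prompting: the JSON schema is appended to the conversation as one more user turn. Below is a self-contained sketch of that behaviour, using plain dicts instead of the library's Message types (an assumption).

# Hedged sketch of the prompt-injection fallback for JSON response formats.
import json
from typing import Any, Dict, List


def append_json_format_hint(
    messages: List[Dict[str, str]], schema: Dict[str, Any]
) -> List[Dict[str, str]]:
    # Mirror the fallback: ask for JSON matching the schema in a final user turn.
    hint = f"Please respond in JSON format with the schema: {json.dumps(schema)}"
    return messages + [{"role": "user", "content": hint}]


msgs = [{"role": "user", "content": "Tell me about Michael Jordan."}]
schema = {"type": "object", "properties": {"first_name": {"type": "string"}}}
for m in append_json_format_hint(msgs, schema):
    print(m["role"], ":", m["content"][:80])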