Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 23:29:43 +00:00)

Commit 40ba22f4c8: add response format to signature
Parent: 6d26bbdce3

15 changed files with 93 additions and 32 deletions
@@ -53,6 +53,7 @@ class InferenceClient(Inference):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -63,6 +64,7 @@ class InferenceClient(Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )

@@ -74,13 +74,18 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None


+class ResponseFormatType(Enum):
+    json = "json"
+    grammar = "grammar"
+
+
 class JsonResponseFormat(BaseModel):
-    type: Literal["json"] = "json"
+    type: Literal[ResponseFormatType.json.value] = ResponseFormatType.json.value
     schema: Dict[str, Any]


 class GrammarResponseFormat(BaseModel):
-    type: Literal["grammar"] = "grammar"
+    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
     bnf: Dict[str, Any]


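The hunk above introduces the payload types that the new response_format parameter carries. As a rough usage sketch, not part of this commit: a caller derives a JSON schema (for example from a pydantic model), wraps it in JsonResponseFormat, and passes it to chat_completion. The import path, the Answer model, the model id, and the pre-built client are assumptions for illustration.

# Hedged sketch: building a JSON response format and passing it through the
# new parameter. Assumes the models above are exported from
# llama_stack.apis.inference as in this commit.
from pydantic import BaseModel

from llama_stack.apis.inference import JsonResponseFormat


class Answer(BaseModel):  # hypothetical target structure
    name: str
    year_born: int


def make_response_format() -> JsonResponseFormat:
    # Any JSON schema works; here it comes from a pydantic model.
    return JsonResponseFormat(schema=Answer.schema())


# Typical call site (compare the test change further down in this diff):
#
#   response = await client.chat_completion(
#       model="Llama3.1-8B-Instruct",
#       messages=sample_messages,
#       response_format=make_response_format(),
#       stream=False,
#   )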
@@ -75,6 +75,7 @@ class InferenceRouter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -102,6 +103,7 @@ class InferenceRouter(Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -52,6 +52,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:

@@ -288,6 +289,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,

@@ -53,6 +53,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -63,6 +64,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -56,6 +56,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -66,6 +67,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -93,6 +93,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -160,6 +161,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -71,6 +71,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -81,6 +82,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -59,6 +59,7 @@ class TogetherInferenceAdapter(
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -69,6 +70,7 @@ class TogetherInferenceAdapter(
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -80,6 +80,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:

@@ -90,6 +91,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -26,6 +26,7 @@ class MockInferenceAPI:
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,

@@ -80,7 +80,7 @@ class Llama:
     def build(
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-        ]
+        ],
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -184,17 +184,11 @@ class Llama:
         echo: bool = False,
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
+        logits_processor: Optional["LogitsProcessor"] = None,
     ) -> Generator:
-        parser = JsonSchemaParser(AnswerFormat.schema())
-        tokenizer_data = build_token_enforcer_tokenizer_data(
-            self.tokenizer, self.args.vocab_size
-        )
-        token_enforcer = TokenEnforcer(tokenizer_data, parser)
-        logits_processor = LogitsProcessor(token_enforcer)
-
         params = self.model.params

-        if print_input_tokens or True:
+        if print_input_tokens:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
@@ -266,10 +260,10 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

-            # print(f"{logits=}")
-            input_ids = tokens[0, :cur_pos].tolist()
-            # logits = logits_processor.process_logits(input_ids, logits)
-            # print(f"{logits=}")
+            if logits_processor is not None:
+                logits = logits_processor.process_logits(
+                    tokens[0, :cur_pos].tolist(), logits
+                )

             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -342,7 +336,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
-            echo=False,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )

     def chat_completion(
@@ -370,6 +368,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )


@@ -398,6 +401,27 @@ def sample_top_p(probs, p):
     return next_token


+def get_logits_processor(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+    response_format: Optional[ResponseFormat],
+) -> Optional["LogitsProcessor"]:
+    if response_format is None:
+        return None
+
+    if response_format.type != ResponseFormatType.json.value:
+        raise ValueError(f"Unsupported response format type {response_format.type}")
+
+    parser = JsonSchemaParser(response_format.schema)
+    data = TokenEnforcerTokenizerData(
+        _build_regular_tokens_list(tokenizer, vocab_size),
+        tokenizer.decode,
+        tokenizer.stop_tokens,
+    )
+    token_enforcer = TokenEnforcer(data, parser)
+    return LogitsProcessor(token_enforcer)
+
+
 class LogitsProcessor:
     def __init__(self, token_enforcer: TokenEnforcer):
         self.token_enforcer = token_enforcer
@@ -437,13 +461,3 @@ def _build_regular_tokens_list(
         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
     return regular_tokens
-
-
-def build_token_enforcer_tokenizer_data(
-    tokenizer: Tokenizer,
-    vocab_size: int,
-) -> TokenEnforcerTokenizerData:
-    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
-    return TokenEnforcerTokenizerData(
-        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
-    )

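In the meta-reference path above, get_logits_processor builds an lm-format-enforcer TokenEnforcer from the request's JSON schema, and the existing LogitsProcessor uses it to constrain decoding: at each step only tokens that keep the output schema-valid stay sampleable. A minimal, self-contained sketch of that masking step follows; the toy vocabulary and allowed-token list stand in for what TokenEnforcer.get_allowed_tokens(...) would return for the real tokenizer, and mask_logits is an illustrative helper, not the commit's code.

# Hedged sketch of logits masking for constrained decoding (pure torch).
from typing import List

import torch


def mask_logits(logits: torch.Tensor, allowed_token_ids: List[int]) -> torch.Tensor:
    # Disallowed tokens get -inf so softmax/argmax can never select them.
    mask = torch.full_like(logits, float("-inf"))
    mask[:, allowed_token_ids] = 0.0
    return logits + mask


vocab_size = 8
logits = torch.randn(1, vocab_size)  # last-position logits for one sequence
allowed = [2, 5]                     # e.g. only '{' or whitespace may follow here
constrained = mask_logits(logits, allowed)

next_token = torch.argmax(constrained, dim=-1)
assert next_token.item() in allowed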
@@ -71,6 +71,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:

@@ -81,6 +82,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )

@@ -186,6 +188,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

@@ -203,6 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )

@@ -97,12 +97,12 @@ class AnswerFormat(BaseModel):

 @pytest.fixture
 def sample_messages():
-    question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
-    question_with_schema = f"{question}{AnswerFormat.schema_json()}"
+    question = "Please give me information about Michael Jordan."
+    # question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
         # UserMessage(content="What's the weather like today?"),
-        UserMessage(content=question_with_schema),
+        UserMessage(content=question),
     ]

@@ -183,10 +183,15 @@ async def test_completion(inference_settings):

 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
+    print(AnswerFormat.schema_json())
+    print(AnswerFormat.schema())
     inference_impl = inference_settings["impl"]
     response = await inference_impl.chat_completion(
         messages=sample_messages,
         stream=False,
+        response_format=JsonResponseFormat(
+            schema=AnswerFormat.schema(),
+        ),
         **inference_settings["common_params"],
     )

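The test above sends AnswerFormat.schema() as the response format instead of pasting the schema into the prompt; AnswerFormat itself appears only in the hunk header, so its fields are not visible in this diff. The sketch below uses hypothetical fields purely to illustrate what gets sent and how a caller can verify the constrained completion by round-tripping it through the same pydantic model.

# Hypothetical stand-in for AnswerFormat (fields invented for illustration)
# plus the kind of check a caller can make on the returned completion.
import json

from pydantic import BaseModel


class AnswerFormat(BaseModel):  # hypothetical fields, not from this diff
    first_name: str
    last_name: str
    year_of_birth: int


def check_completion(completion_text: str) -> AnswerFormat:
    # Raises pydantic.ValidationError if the output strayed from the schema.
    return AnswerFormat(**json.loads(completion_text))


print(json.dumps(AnswerFormat.schema(), indent=2))  # what the backend receives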
@@ -70,11 +70,25 @@ def chat_completion_request_to_messages(
         and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
-        return augment_messages_for_tools_llama_3_1(request)
+        messages = augment_messages_for_tools_llama_3_1(request)
     elif model.model_family == ModelFamily.llama3_2:
-        return augment_messages_for_tools_llama_3_2(request)
+        messages = augment_messages_for_tools_llama_3_2(request)
     else:
-        return request.messages
+        messages = request.messages

+    if fmt := request.response_format:
+        if fmt.type == ResponseFormatType.json.value:
+            messages.append(
+                UserMessage(
+                    content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
+                )
+            )
+        elif fmt.type == ResponseFormatType.grammar.value:
+            raise NotImplementedError("Grammar response format not supported yet")
+        else:
+            raise ValueError(f"Unknown response format {fmt.type}")
+
+    return messages


 def augment_messages_for_tools_llama_3_1(
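For providers without decode-time enforcement, the hunk above falls back to prompt steering: the JSON schema is serialized and appended to the conversation as an extra user message. A small, self-contained sketch of what that appended message contains; the schema dict here is illustrative, the real one comes from request.response_format.schema.

# Hedged sketch of the prompt-injected schema instruction.
import json

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "year_born": {"type": "integer"}},
    "required": ["name", "year_born"],
}

content = f"Please respond in JSON format with the schema: {json.dumps(schema)}"
print(content)
# -> Please respond in JSON format with the schema: {"type": "object", ...}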