Avoid warnings from pydantic for overriding schema

Also fix structured output in completions
This commit is contained in:
Ashwin Bharambe 2024-10-28 13:36:17 -07:00
parent ed833bb758
commit eccd7dc4a9
7 changed files with 50 additions and 26 deletions

View file

@ -86,11 +86,11 @@ class ResponseFormatType(Enum):
grammar = "grammar"
class JsonResponseFormat(BaseModel):
class JsonSchemaResponseFormat(BaseModel):
type: Literal[ResponseFormatType.json_schema.value] = (
ResponseFormatType.json_schema.value
)
schema: Dict[str, Any]
json_schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
@ -99,7 +99,7 @@ class GrammarResponseFormat(BaseModel):
ResponseFormat = Annotated[
Union[JsonResponseFormat, GrammarResponseFormat],
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]

View file

@ -163,7 +163,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.schema,
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {

View file

@ -109,7 +109,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.schema,
"value": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")

View file

@ -121,7 +121,7 @@ class TogetherInferenceAdapter(
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.schema,
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")

View file

@ -39,6 +39,7 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
@ -346,7 +347,10 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = self.formatter.encode_content(content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
@ -451,7 +455,7 @@ def get_logits_processor(
if response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Unsupported response format type {response_format.type}")
parser = JsonSchemaParser(response_format.schema)
parser = JsonSchemaParser(response_format.json_schema)
data = TokenEnforcerTokenizerData(
_build_regular_tokens_list(tokenizer, vocab_size),
tokenizer.decode,

View file

@ -174,6 +174,7 @@ async def test_completion(inference_settings):
@pytest.mark.asyncio
@pytest.mark.skip("This test is not quite robust")
async def test_completions_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
@ -196,14 +197,14 @@ async def test_completions_structured_output(inference_settings):
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion(
content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ",
content=user_input,
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
response_format=JsonResponseFormat(
schema=Output.model_json_schema(),
response_format=JsonSchemaResponseFormat(
json_schema=Output.model_json_schema(),
),
)
assert isinstance(response, CompletionResponse)
@ -256,8 +257,8 @@ async def test_structured_output(inference_settings):
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
response_format=JsonResponseFormat(
schema=AnswerFormat.model_json_schema(),
response_format=JsonSchemaResponseFormat(
json_schema=AnswerFormat.model_json_schema(),
),
**inference_settings["common_params"],
)

View file

@ -27,17 +27,33 @@ from llama_stack.providers.utils.inference import supported_inference_models
def completion_request_to_prompt(
request: CompletionRequest, formatter: ChatFormat
) -> str:
model_input = formatter.encode_content(request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = formatter.encode_content(content)
return formatter.tokenizer.decode(model_input.tokens)
def completion_request_to_prompt_model_input_info(
request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
model_input = formatter.encode_content(request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = formatter.encode_content(content)
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
def augment_content_with_response_format_prompt(response_format, content):
if fmt_prompt := response_format_prompt(response_format):
if isinstance(content, list):
return content + [fmt_prompt]
else:
return [content, fmt_prompt]
return content
def chat_completion_request_to_prompt(
request: ChatCompletionRequest, formatter: ChatFormat
) -> str:
@ -84,21 +100,24 @@ def chat_completion_request_to_messages(
else:
messages = request.messages
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
messages.append(
UserMessage(
content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
)
)
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if fmt_prompt := response_format_prompt(request.response_format):
messages.append(UserMessage(content=fmt_prompt))
return messages
def response_format_prompt(fmt: Optional[ResponseFormat]):
if not fmt:
return None
if fmt.type == ResponseFormatType.json_schema.value:
return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]: