Avoid warnings from pydantic for overriding schema

Also fix structured output in completions
This commit is contained in:
Ashwin Bharambe 2024-10-28 13:36:17 -07:00
parent ed833bb758
commit eccd7dc4a9
7 changed files with 50 additions and 26 deletions

View file

@ -27,17 +27,33 @@ from llama_stack.providers.utils.inference import supported_inference_models
def completion_request_to_prompt(
request: CompletionRequest, formatter: ChatFormat
) -> str:
model_input = formatter.encode_content(request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = formatter.encode_content(content)
return formatter.tokenizer.decode(model_input.tokens)
def completion_request_to_prompt_model_input_info(
request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
model_input = formatter.encode_content(request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = formatter.encode_content(content)
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
def augment_content_with_response_format_prompt(response_format, content):
if fmt_prompt := response_format_prompt(response_format):
if isinstance(content, list):
return content + [fmt_prompt]
else:
return [content, fmt_prompt]
return content
def chat_completion_request_to_prompt(
request: ChatCompletionRequest, formatter: ChatFormat
) -> str:
@ -84,21 +100,24 @@ def chat_completion_request_to_messages(
else:
messages = request.messages
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
messages.append(
UserMessage(
content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
)
)
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if fmt_prompt := response_format_prompt(request.response_format):
messages.append(UserMessage(content=fmt_prompt))
return messages
def response_format_prompt(fmt: Optional[ResponseFormat]):
if not fmt:
return None
if fmt.type == ResponseFormatType.json_schema.value:
return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]: