Avoid warnings from pydantic for overriding schema
Also fix structured output in completions
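
For context on the fmt.schema -> fmt.json_schema rename in the diff below: pydantic v2 warns when a model field shadows an attribute inherited from BaseModel, and BaseModel still carries a deprecated schema() method. A minimal standalone sketch of the warning and the fix; the model names here are illustrative, not the exact llama-stack definitions:

```python
from pydantic import BaseModel, Field

class Shadowing(BaseModel):
    # A field literally named `schema` shadows the deprecated
    # BaseModel.schema() method; defining this class makes pydantic v2
    # emit a UserWarning along the lines of
    # 'Field name "schema" ... shadows an attribute in parent "BaseModel"'.
    schema: dict

class Fixed(BaseModel):
    # Renaming the field avoids the warning; an alias can keep `schema`
    # as the external JSON key if the wire format must stay the same.
    json_schema: dict = Field(alias="schema")
```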
parent ed833bb758
commit eccd7dc4a9
7 changed files with 50 additions and 26 deletions
@@ -27,17 +27,33 @@ from llama_stack.providers.utils.inference import supported_inference_models
 def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
-    model_input = formatter.encode_content(request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
+    model_input = formatter.encode_content(content)
     return formatter.tokenizer.decode(model_input.tokens)


 def completion_request_to_prompt_model_input_info(
     request: CompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    model_input = formatter.encode_content(request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
+    model_input = formatter.encode_content(content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))


+def augment_content_with_response_format_prompt(response_format, content):
+    if fmt_prompt := response_format_prompt(response_format):
+        if isinstance(content, list):
+            return content + [fmt_prompt]
+        else:
+            return [content, fmt_prompt]
+
+    return content
+
+
 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
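A self-contained sketch of how the new augment_content_with_response_format_prompt helper behaves, using a plain dict in place of the llama-stack ResponseFormat type and only the JSON-schema branch of response_format_prompt:

```python
import json

# Stand-in for response_format_prompt, reduced to the JSON-schema branch;
# the real helper (see the second hunk) also handles grammar and unknown types.
def response_format_prompt(fmt):
    if not fmt:
        return None
    return f"Please respond in JSON format with the schema: {json.dumps(fmt['json_schema'])}"

def augment_content_with_response_format_prompt(response_format, content):
    # Append the format prompt to list content, or wrap scalar content
    # in a two-element list; pass content through when no format is set.
    if fmt_prompt := response_format_prompt(response_format):
        if isinstance(content, list):
            return content + [fmt_prompt]
        return [content, fmt_prompt]
    return content

fmt = {"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}}}
print(augment_content_with_response_format_prompt(fmt, "Extract the name."))
# -> ['Extract the name.', 'Please respond in JSON format with the schema: ...']
print(augment_content_with_response_format_prompt(None, "Extract the name."))
# -> 'Extract the name.'
```

This is the "fix structured output in completions" part: the completion helpers previously encoded request.content directly and ignored request.response_format; both now run the content through this augmentation before encoding.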
@@ -84,21 +100,24 @@ def chat_completion_request_to_messages(
     else:
         messages = request.messages

-    if fmt := request.response_format:
-        if fmt.type == ResponseFormatType.json_schema.value:
-            messages.append(
-                UserMessage(
-                    content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
-                )
-            )
-        elif fmt.type == ResponseFormatType.grammar.value:
-            raise NotImplementedError("Grammar response format not supported yet")
-        else:
-            raise ValueError(f"Unknown response format {fmt.type}")
+    if fmt_prompt := response_format_prompt(request.response_format):
+        messages.append(UserMessage(content=fmt_prompt))

     return messages


+def response_format_prompt(fmt: Optional[ResponseFormat]):
+    if not fmt:
+        return None
+
+    if fmt.type == ResponseFormatType.json_schema.value:
+        return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
+    elif fmt.type == ResponseFormatType.grammar.value:
+        raise NotImplementedError("Grammar response format not supported yet")
+    else:
+        raise ValueError(f"Unknown response format {fmt.type}")
+
+
 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
 ) -> List[Message]:
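Likewise, a runnable sketch of the shared response_format_prompt branching, with minimal stand-ins for the llama-stack ResponseFormat and ResponseFormatType types; the chat path now appends the helper's return value as a trailing user message instead of duplicating the f-string inline:

```python
import json
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

# Minimal stand-ins for the llama-stack types referenced in the diff above.
class ResponseFormatType(Enum):
    json_schema = "json_schema"
    grammar = "grammar"

@dataclass
class ResponseFormat:
    type: str
    json_schema: dict = field(default_factory=dict)

def response_format_prompt(fmt: Optional[ResponseFormat]):
    if not fmt:
        return None
    if fmt.type == ResponseFormatType.json_schema.value:
        return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
    elif fmt.type == ResponseFormatType.grammar.value:
        raise NotImplementedError("Grammar response format not supported yet")
    else:
        raise ValueError(f"Unknown response format {fmt.type}")

fmt = ResponseFormat(type="json_schema", json_schema={"type": "object"})
print(response_format_prompt(fmt))
# -> Please respond in JSON format with the schema: {"type": "object"}
print(response_format_prompt(None))  # -> None

# In chat_completion_request_to_messages the same value is appended as a
# final user message (UserMessage in llama-stack; a plain dict here):
messages = [{"role": "user", "content": "Extract the name."}]
if fmt_prompt := response_format_prompt(fmt):
    messages.append({"role": "user", "content": fmt_prompt})
```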