Avoid warnings from pydantic for overriding schema

Also fix structured output in completions
This commit is contained in:
Ashwin Bharambe 2024-10-28 13:36:17 -07:00
parent ed833bb758
commit eccd7dc4a9
7 changed files with 50 additions and 26 deletions

View file

@@ -39,6 +39,7 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
+augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
@@ -346,7 +347,10 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
-model_input = self.formatter.encode_content(request.content)
+content = augment_content_with_response_format_prompt(
+request.response_format, request.content
+)
+model_input = self.formatter.encode_content(content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
@@ -451,7 +455,7 @@ def get_logits_processor(
if response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Unsupported response format type {response_format.type}")
-parser = JsonSchemaParser(response_format.schema)
+parser = JsonSchemaParser(response_format.json_schema)
data = TokenEnforcerTokenizerData(
_build_regular_tokens_list(tokenizer, vocab_size),
tokenizer.decode,