Fix and add a test

Ashwin Bharambe 2024-10-21 22:02:37 -07:00
parent 40ba22f4c8
commit cd84dee3e9
5 changed files with 76 additions and 62 deletions

View file

@@ -75,17 +75,19 @@ class ChatCompletionResponseEvent(BaseModel):
class ResponseFormatType(Enum):
json = "json"
json_schema = "json_schema"
grammar = "grammar"
class JsonResponseFormat(BaseModel):
type: Literal[ResponseFormat.json.value] = ResponseFormat.json.value
type: Literal[ResponseFormatType.json_schema.value] = (
ResponseFormatType.json_schema.value
)
schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
type: Literal[ResponseFormat.grammar.value] = ResponseFormat.grammar.value
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
bnf: Dict[str, Any]
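
For context, a hedged sketch of how a caller might construct the reworked response-format types from this hunk. The class names, field names, and enum values are exactly those shown above; the schema and BNF contents are made up for illustration and not part of this commit.

from llama_stack.apis.inference import *  # noqa: F403

# Hypothetical usage sketch.
fmt = JsonResponseFormat(
    schema={
        "type": "object",
        "properties": {"first_name": {"type": "string"}},
        "required": ["first_name"],
    },
)
# fmt.type defaults to ResponseFormatType.json_schema.value, i.e. "json_schema".

# Grammar-constrained output is declared the same way, but per the prompt
# adapter change at the end of this diff it is not implemented yet.
grammar_fmt = GrammarResponseFormat(bnf={"root": "placeholder grammar"})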

View file

@@ -89,6 +89,7 @@ class InferenceRouter(Inference):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@@ -112,6 +113,7 @@ class InferenceRouter(Inference):
model=model,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
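
The two router hunks above only thread the new response_format argument through to whichever provider backs the model. A hedged sketch of a call site that exercises it; the helper name, model id, and schema are placeholders, while the message and format types are the ones used elsewhere in this diff.

from llama_stack.apis.inference import *  # noqa: F403


async def ask_structured(inference_impl, model_id: str):
    # Hypothetical helper; mirrors the test added later in this commit.
    return await inference_impl.chat_completion(
        model=model_id,  # placeholder model identifier
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content="Please give me information about Michael Jordan."),
        ],
        response_format=JsonResponseFormat(
            schema={"type": "object", "properties": {"first_name": {"type": "string"}}},
        ),
        stream=False,
    )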

View file

@@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import math
import os
import sys
import time
@@ -34,11 +35,8 @@ from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
import math
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -48,13 +46,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -261,9 +252,7 @@ class Llama:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor.process_logits(
tokens[0, :cur_pos].tolist(), logits
)
logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
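
A side note on why the mask uses -inf (see the LogitsProcessor below): after adding -inf to disallowed positions, softmax gives them exactly zero probability, so neither temperature sampling nor greedy decoding can ever emit them. A small self-contained illustration, independent of this codebase:

import math
import torch

scores = torch.randn(1, 6)                  # toy logits over a 6-token vocabulary
mask = torch.full_like(scores, -math.inf)
mask[:, [2, 5]] = 0.0                       # pretend only tokens 2 and 5 are allowed

probs = torch.softmax((scores + mask) / 0.7, dim=-1)    # temperature 0.7
assert probs[0, [0, 1, 3, 4]].sum().item() == 0.0       # disallowed -> exactly zero
assert abs(probs[0, [2, 5]].sum().item() - 1.0) < 1e-6  # all mass on allowed tokens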
@@ -401,15 +390,36 @@ def sample_top_p(probs, p):
return next_token
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, tokens: torch.Tensor, scores: torch.Tensor
) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: Optional[ResponseFormat],
) -> Optional[LogitsProcessor]:
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
if response_format.type != ResponseFormatType.json:
if response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Unsupported response format type {response_format.type}")
parser = JsonSchemaParser(response_format.schema)
@@ -422,28 +432,6 @@ def get_logits_processor(
return LogitsProcessor(token_enforcer)
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, input_ids: List[int], scores: torch.Tensor
) -> torch.Tensor:
token_sequence = input_ids
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
# print(f"{allowed_tokens=}")
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def _build_regular_tokens_list(
tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
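
Zooming out, the control flow this file now implements is standard constrained decoding: at each step the token enforcer reports which token ids can legally extend the partial output, everything else is masked to -inf, and sampling proceeds as usual. A library-free sketch of that loop, with get_allowed_tokens standing in for lm-format-enforcer's TokenEnforcer.get_allowed_tokens used above; it is a hypothetical illustration, not the provider's code.

import math
from typing import Callable, List

import torch


def constrained_decode(
    forward: Callable[[List[int]], torch.Tensor],         # token ids -> (vocab,) logits
    get_allowed_tokens: Callable[[List[int]], List[int]],
    prompt: List[int],
    stop_id: int,
    max_new_tokens: int = 64,
) -> List[int]:
    # Greedy for simplicity; the real generator samples with temperature/top-p.
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        logits = forward(tokens)
        mask = torch.full_like(logits, -math.inf)
        mask[get_allowed_tokens(tokens)] = 0.0             # only legal continuations survive
        next_id = int(torch.argmax(logits + mask))
        tokens.append(next_id)
        if next_id == stop_id:
            break
    return tokens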

View file

@@ -10,6 +10,8 @@ import os
import pytest
import pytest_asyncio
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@@ -85,24 +87,11 @@ async def inference_settings(request):
}
from pydantic import BaseModel
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
@pytest.fixture
def sample_messages():
question = "Please give me information about Michael Jordan."
# question_with_schema = f"{question}{AnswerFormat.schema_json()}"
return [
SystemMessage(content="You are a helpful assistant."),
# UserMessage(content="What's the weather like today?"),
UserMessage(content=question),
UserMessage(content="What's the weather like today?"),
]
@@ -183,23 +172,55 @@ async def test_completion(inference_settings):
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
print(AnswerFormat.schema_json())
print(AnswerFormat.schema())
inference_impl = inference_settings["impl"]
response = await inference_impl.chat_completion(
messages=sample_messages,
stream=False,
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.asyncio
async def test_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_id__ != "meta-reference":
pytest.skip("Other inference providers don't support structured output yet")
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
response = await inference_impl.chat_completion(
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
response_format=JsonResponseFormat(
schema=AnswerFormat.schema(),
),
**inference_settings["common_params"],
)
print(response)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
answer = AnswerFormat.parse_raw(response.completion_message.content)
assert answer.first_name == "Michael"
assert answer.last_name == "Jordan"
assert answer.year_of_birth == 1963
assert answer.num_seasons_in_nba == 15
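
The essence of the new test, stated standalone: with structured output enforced, the completion content must be a JSON document that round-trips through the pydantic model. A minimal illustration with a hand-written (hypothetical) completion string:

from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


# Hypothetical model output; the test above asserts these exact values.
raw = (
    '{"first_name": "Michael", "last_name": "Jordan", '
    '"year_of_birth": 1963, "num_seasons_in_nba": 15}'
)
answer = AnswerFormat.parse_raw(raw)
assert answer.first_name == "Michael" and answer.num_seasons_in_nba == 15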
@pytest.mark.asyncio

View file

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Tuple
from llama_models.llama3.api.chat_format import ChatFormat
@@ -77,13 +78,13 @@ def chat_completion_request_to_messages(
messages = request.messages
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json:
if fmt.type == ResponseFormatType.json_schema.value:
messages.append(
UserMessage(
content=f"Please response in JSON format with the schema: {json.dumps(fmt.schema)}"
content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
)
)
elif fmt.type == ResponseFormatType.grammar:
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
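
For a concrete sense of what the json_schema branch above appends to the conversation, a small hedged example with a toy schema (the wording comes from this hunk; the schema itself is made up):

import json

schema = {"type": "object", "properties": {"first_name": {"type": "string"}}}
appended = f"Please respond in JSON format with the schema: {json.dumps(schema)}"
# The resulting UserMessage content embeds the serialized schema verbatim.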