Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Fix and add a test
This commit is contained in:
parent 40ba22f4c8
commit cd84dee3e9

5 changed files with 76 additions and 62 deletions
@@ -75,17 +75,19 @@ class ChatCompletionResponseEvent(BaseModel):
 class ResponseFormatType(Enum):
-    json = "json"
+    json_schema = "json_schema"
     grammar = "grammar"
 
 
 class JsonResponseFormat(BaseModel):
-    type: Literal[ResponseFormat.json.value] = ResponseFormat.json.value
+    type: Literal[ResponseFormatType.json_schema.value] = (
+        ResponseFormatType.json_schema.value
+    )
     schema: Dict[str, Any]
 
 
 class GrammarResponseFormat(BaseModel):
-    type: Literal[ResponseFormat.grammar.value] = ResponseFormat.grammar.value
+    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
     bnf: Dict[str, Any]
 
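Not part of the diff: a minimal sketch of how a caller is expected to build the new JSON response format, mirroring the usage in the test added further down. The import path is inferred from that test's wildcard import and the Person model is a made-up example.

# Usage sketch (assumptions noted inline), not code from this commit.
from pydantic import BaseModel

from llama_stack.apis.inference import JsonResponseFormat  # assumed import path


class Person(BaseModel):  # hypothetical schema, analogous to the test's AnswerFormat
    first_name: str
    last_name: str


# Pass the Pydantic-generated JSON schema, exactly as the new test does with
# AnswerFormat.schema().
fmt = JsonResponseFormat(schema=Person.schema())
print(fmt.type)  # "json_schema"
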
@@ -89,6 +89,7 @@ class InferenceRouter(Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -112,6 +113,7 @@ class InferenceRouter(Inference):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
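Not part of the diff: with the router now forwarding response_format, a caller can thread it through chat_completion roughly as the new test does. The helper below is a hedged sketch; the model identifier and the toy schema are assumptions, the imports mirror the test file.

# Call-shape sketch only; it needs a routed Inference implementation passed in.
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403


async def ask_structured(inference_impl) -> None:
    response = await inference_impl.chat_completion(
        model="Llama3.1-8B-Instruct",  # assumed model identifier
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content="Please give me information about Michael Jordan."),
        ],
        response_format=JsonResponseFormat(schema={"type": "object"}),  # toy schema
        stream=False,
    )
    print(response.completion_message.content)
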
@@ -8,6 +8,7 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import json
+import math
 import os
 import sys
 import time
@@ -34,11 +35,8 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.inference import *  # noqa: F403
-import math
 
-from lmformatenforcer import JsonSchemaParser
-
-from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -48,13 +46,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
-class AnswerFormat(BaseModel):
-    first_name: str
-    last_name: str
-    year_of_birth: int
-    num_seasons_in_nba: int
-
-
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
 
@@ -261,9 +252,7 @@ class Llama:
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
             if logits_processor is not None:
-                logits = logits_processor.process_logits(
-                    tokens[0, :cur_pos].tolist(), logits
-                )
+                logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
 
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -401,15 +390,36 @@ def sample_top_p(probs, p):
     return next_token
 
 
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, tokens: torch.Tensor, scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = tokens[0, :].tolist()
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
 def get_logits_processor(
     tokenizer: Tokenizer,
     vocab_size: int,
     response_format: Optional[ResponseFormat],
-) -> Optional[LogitsProcessor]:
+) -> Optional["LogitsProcessor"]:
     if response_format is None:
         return None
 
-    if response_format.type != ResponseFormatType.json:
+    if response_format.type != ResponseFormatType.json_schema.value:
         raise ValueError(f"Unsupported response format type {response_format.type}")
 
     parser = JsonSchemaParser(response_format.schema)
@@ -422,28 +432,6 @@ def get_logits_processor(
     return LogitsProcessor(token_enforcer)
 
 
-class LogitsProcessor:
-    def __init__(self, token_enforcer: TokenEnforcer):
-        self.token_enforcer = token_enforcer
-        self.mask: Optional[torch.Tensor] = None
-
-    def process_logits(
-        self, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        token_sequence = input_ids
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
-        # print(f"{allowed_tokens=}")
-        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
-        if self.mask is not None:
-            self.mask.fill_(-math.inf)
-        else:
-            self.mask = torch.full_like(scores, -math.inf)
-
-        self.mask[:, :, allowed_tokens] = 0
-        scores = scores + self.mask
-        return scores
-
-
 def _build_regular_tokens_list(
     tokenizer: Tokenizer, vocab_size: int
 ) -> List[Tuple[int, str, bool]]:
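Not part of the diff: the relocated LogitsProcessor constrains sampling by adding a mask of -inf to every token the schema-driven TokenEnforcer disallows. Below is a tiny self-contained torch sketch of that masking step; the allowed-token set is hard-coded here instead of coming from lmformatenforcer.

# Standalone illustration of the masking pattern used in process_logits above.
import math

import torch

scores = torch.randn(1, 1, 8)   # (batch, seq, vocab) slice of logits
allowed_tokens = [2, 5]         # pretend the JSON grammar allows only these

mask = torch.full_like(scores, -math.inf)
mask[:, :, allowed_tokens] = 0  # allowed tokens keep their original logit
constrained = scores + mask

next_token = torch.argmax(constrained[:, -1], dim=-1)
assert next_token.item() in allowed_tokens
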
@@ -10,6 +10,8 @@ import os
 import pytest
 import pytest_asyncio
 
+from pydantic import BaseModel
+
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 
@@ -85,24 +87,11 @@ async def inference_settings(request):
     }
 
 
-from pydantic import BaseModel
-
-
-class AnswerFormat(BaseModel):
-    first_name: str
-    last_name: str
-    year_of_birth: int
-    num_seasons_in_nba: int
-
-
 @pytest.fixture
 def sample_messages():
-    question = "Please give me information about Michael Jordan."
-    # question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
-        # UserMessage(content="What's the weather like today?"),
-        UserMessage(content=question),
+        UserMessage(content="What's the weather like today?"),
     ]
 
@@ -183,23 +172,55 @@ async def test_completion(inference_settings):
 
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
-    print(AnswerFormat.schema_json())
-    print(AnswerFormat.schema())
     inference_impl = inference_settings["impl"]
     response = await inference_impl.chat_completion(
         messages=sample_messages,
         stream=False,
         **inference_settings["common_params"],
     )
 
     assert isinstance(response, ChatCompletionResponse)
     assert response.completion_message.role == "assistant"
     assert isinstance(response.completion_message.content, str)
     assert len(response.completion_message.content) > 0
 
 
+@pytest.mark.asyncio
+async def test_structured_output(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_id__ != "meta-reference":
+        pytest.skip("Other inference providers don't support structured output yet")
+
+    class AnswerFormat(BaseModel):
+        first_name: str
+        last_name: str
+        year_of_birth: int
+        num_seasons_in_nba: int
+
+    response = await inference_impl.chat_completion(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Please give me information about Michael Jordan."),
+        ],
+        stream=False,
+        response_format=JsonResponseFormat(
+            schema=AnswerFormat.schema(),
+        ),
+        **inference_settings["common_params"],
+    )
+
+    print(response)
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.completion_message.role == "assistant"
+    assert isinstance(response.completion_message.content, str)
+    assert len(response.completion_message.content) > 0
+
+    answer = AnswerFormat.parse_raw(response.completion_message.content)
+    assert answer.first_name == "Michael"
+    assert answer.last_name == "Jordan"
+    assert answer.year_of_birth == 1963
+    assert answer.num_seasons_in_nba == 15
+
+
 @pytest.mark.asyncio
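Not part of the diff: the new test validates structured output by round-tripping the model's text through the Pydantic model. A self-contained illustration of the response shape it expects, run against a hand-written string instead of a live model:

# Same assertion logic as test_structured_output, minus the inference call.
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


raw = '{"first_name": "Michael", "last_name": "Jordan", "year_of_birth": 1963, "num_seasons_in_nba": 15}'
answer = AnswerFormat.parse_raw(raw)
assert answer.first_name == "Michael"
assert answer.num_seasons_in_nba == 15
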
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 from typing import Tuple
 
 from llama_models.llama3.api.chat_format import ChatFormat
@@ -77,13 +78,13 @@ def chat_completion_request_to_messages(
     messages = request.messages
 
     if fmt := request.response_format:
-        if fmt.type == ResponseFormatType.json:
+        if fmt.type == ResponseFormatType.json_schema.value:
             messages.append(
                 UserMessage(
-                    content=f"Please response in JSON format with the schema: {json.dumps(fmt.schema)}"
+                    content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
                 )
             )
-        elif fmt.type == ResponseFormatType.grammar:
+        elif fmt.type == ResponseFormatType.grammar.value:
             raise NotImplementedError("Grammar response format not supported yet")
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
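Not part of the diff: what the appended user message looks like for a small schema; the schema below is a made-up example.

# Reproduces the prompt string chat_completion_request_to_messages now appends
# for a json_schema response format, using a toy schema.
import json

schema = {
    "type": "object",
    "properties": {"first_name": {"type": "string"}, "year_of_birth": {"type": "integer"}},
}
print(f"Please respond in JSON format with the schema: {json.dumps(schema)}")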