From cd84dee3e99141ee635d932fb24ddc7e4873972c Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 21 Oct 2024 22:02:37 -0700
Subject: [PATCH] Fix and add a test

---
 llama_stack/apis/inference/inference.py       |  8 ++-
 llama_stack/distribution/routers/routers.py   |  2 +
 .../meta_reference/inference/generation.py    | 64 ++++++++-----------
 .../tests/inference/test_inference.py         | 57 +++++++++++------
 .../utils/inference/prompt_adapter.py         |  7 +-
 5 files changed, 76 insertions(+), 62 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index f256901ee..4ee01acae 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -75,17 +75,19 @@ class ChatCompletionResponseEvent(BaseModel):
 
 
 class ResponseFormatType(Enum):
-    json = "json"
+    json_schema = "json_schema"
     grammar = "grammar"
 
 
 class JsonResponseFormat(BaseModel):
-    type: Literal[ResponseFormat.json.value] = ResponseFormat.json.value
+    type: Literal[ResponseFormatType.json_schema.value] = (
+        ResponseFormatType.json_schema.value
+    )
     schema: Dict[str, Any]
 
 
 class GrammarResponseFormat(BaseModel):
-    type: Literal[ResponseFormat.grammar.value] = ResponseFormat.grammar.value
+    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
     bnf: Dict[str, Any]
 
 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index b33c5ec36..26a0988ed 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -89,6 +89,7 @@ class InferenceRouter(Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -112,6 +113,7 @@ class InferenceRouter(Inference):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index fc1c809ad..b424a9347 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -8,6 +8,7 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import json
+import math
 import os
 import sys
 import time
@@ -34,11 +35,8 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.inference import *  # noqa: F403
-import math
-
-from lmformatenforcer import JsonSchemaParser
-
-from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -48,13 +46,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
-class AnswerFormat(BaseModel):
-    first_name: str
-    last_name: str
-    year_of_birth: int
-    num_seasons_in_nba: int
-
-
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
 
@@ -261,9 +252,7 @@ class Llama:
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
             if logits_processor is not None:
-                logits = logits_processor.process_logits(
-                    tokens[0, :cur_pos].tolist(), logits
-                )
+                logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
 
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -401,15 +390,36 @@ def sample_top_p(probs, p):
     return next_token
 
 
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, tokens: torch.Tensor, scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = tokens[0, :].tolist()
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
 def get_logits_processor(
     tokenizer: Tokenizer,
     vocab_size: int,
     response_format: Optional[ResponseFormat],
-) -> Optional[LogitsProcessor]:
+) -> Optional["LogitsProcessor"]:
     if response_format is None:
         return None
 
-    if response_format.type != ResponseFormatType.json:
+    if response_format.type != ResponseFormatType.json_schema.value:
         raise ValueError(f"Unsupported response format type {response_format.type}")
 
     parser = JsonSchemaParser(response_format.schema)
@@ -422,28 +432,6 @@ def get_logits_processor(
     return LogitsProcessor(token_enforcer)
 
 
-class LogitsProcessor:
-    def __init__(self, token_enforcer: TokenEnforcer):
-        self.token_enforcer = token_enforcer
-        self.mask: Optional[torch.Tensor] = None
-
-    def process_logits(
-        self, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        token_sequence = input_ids
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
-        # print(f"{allowed_tokens=}")
-        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
-        if self.mask is not None:
-            self.mask.fill_(-math.inf)
-        else:
-            self.mask = torch.full_like(scores, -math.inf)
-
-        self.mask[:, :, allowed_tokens] = 0
-        scores = scores + self.mask
-        return scores
-
-
 def _build_regular_tokens_list(
     tokenizer: Tokenizer, vocab_size: int
 ) -> List[Tuple[int, str, bool]]:
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 18c6327db..3e61337eb 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -10,6 +10,8 @@ import os
 import pytest
 import pytest_asyncio
 
+from pydantic import BaseModel
+
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 
@@ -85,24 +87,11 @@ async def inference_settings(request):
     }
 
 
-from pydantic import BaseModel
-
-
-class AnswerFormat(BaseModel):
-    first_name: str
-    last_name: str
-    year_of_birth: int
-    num_seasons_in_nba: int
-
-
 @pytest.fixture
 def sample_messages():
-    question = "Please give me information about Michael Jordan."
-    # question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
-        # UserMessage(content="What's the weather like today?"),
-        UserMessage(content=question),
+        UserMessage(content="What's the weather like today?"),
     ]
 
 
@@ -183,23 +172,55 @@ async def test_completion(inference_settings):
 
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
-    print(AnswerFormat.schema_json())
-    print(AnswerFormat.schema())
     inference_impl = inference_settings["impl"]
     response = await inference_impl.chat_completion(
         messages=sample_messages,
         stream=False,
+        **inference_settings["common_params"],
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.completion_message.role == "assistant"
+    assert isinstance(response.completion_message.content, str)
+    assert len(response.completion_message.content) > 0
+
+
+@pytest.mark.asyncio
+async def test_structured_output(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_id__ != "meta-reference":
+        pytest.skip("Other inference providers don't support structured output yet")
+
+    class AnswerFormat(BaseModel):
+        first_name: str
+        last_name: str
+        year_of_birth: int
+        num_seasons_in_nba: int
+
+    response = await inference_impl.chat_completion(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Please give me information about Michael Jordan."),
+        ],
+        stream=False,
         response_format=JsonResponseFormat(
             schema=AnswerFormat.schema(),
         ),
         **inference_settings["common_params"],
     )
-    print(response)
 
     assert isinstance(response, ChatCompletionResponse)
     assert response.completion_message.role == "assistant"
     assert isinstance(response.completion_message.content, str)
-    assert len(response.completion_message.content) > 0
+
+    answer = AnswerFormat.parse_raw(response.completion_message.content)
+    assert answer.first_name == "Michael"
+    assert answer.last_name == "Jordan"
+    assert answer.year_of_birth == 1963
+    assert answer.num_seasons_in_nba == 15
 
 
 @pytest.mark.asyncio
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index cab2e5169..48f1df02f 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 from typing import Tuple
 
 from llama_models.llama3.api.chat_format import ChatFormat
@@ -77,13 +78,13 @@ def chat_completion_request_to_messages(
     messages = request.messages
 
     if fmt := request.response_format:
-        if fmt.type == ResponseFormatType.json:
+        if fmt.type == ResponseFormatType.json_schema.value:
             messages.append(
                 UserMessage(
-                    content=f"Please response in JSON format with the schema: {json.dumps(fmt.schema)}"
+                    content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
                 )
            )
-        elif fmt.type == ResponseFormatType.grammar:
+        elif fmt.type == ResponseFormatType.grammar.value:
            raise NotImplementedError("Grammar response format not supported yet")
         else:
            raise ValueError(f"Unknown response format {fmt.type}")
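The constrained-decoding path in generation.py works by additive masking: the mask starts at -inf everywhere, the token ids returned by TokenEnforcer.get_allowed_tokens are reset to 0, and the mask is added to the raw logits, so softmax gives zero probability to every disallowed token while preserving the relative scores of the allowed ones. A minimal standalone sketch of that masking step, with a toy vocabulary and a hard-coded allowed set standing in for the enforcer output (only torch is required):

import math

import torch

vocab_size = 8
scores = torch.randn(1, 1, vocab_size)  # (batch, seq, vocab), same layout as the logits above
allowed_tokens = [2, 5, 7]  # toy stand-in for TokenEnforcer.get_allowed_tokens(prefix)

mask = torch.full_like(scores, -math.inf)  # everything forbidden by default
mask[:, :, allowed_tokens] = 0             # allowed ids keep their original logits
constrained = scores + mask

probs = torch.softmax(constrained[:, -1], dim=-1)
print(probs)  # non-zero mass only on ids 2, 5 and 7; still sums to 1

On the prompting side, chat_completion_request_to_messages appends a user message that carries the JSON schema of the requested output. Assuming a pydantic model like the AnswerFormat used in the test, the appended hint looks roughly like this (a sketch, not the exact provider output):

import json

from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


# Assumption: pydantic v1-style .schema(), matching the calls used in the patch.
hint = f"Please respond in JSON format with the schema: {json.dumps(AnswerFormat.schema())}"
print(hint)

The test then closes the loop with AnswerFormat.parse_raw on the completion text, so the same schema both constrains decoding and validates the parsed result.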