From c06718fbd54ede05c47b6e6c2b21c9f9ead98ec9 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 22 Oct 2024 12:53:34 -0700 Subject: [PATCH] Add support for Structured Output / Guided decoding (#281) Added support for structured output in the API and added a reference implementation for meta-reference. A few notes: * Two formats are specified in the API: Json schema and EBNF based grammar * Implementation only supports Json for now We use lm-format-enhancer to provide the implementation right now but may change this especially because BNF grammars aren't supported by that library. Fireworks has support for structured output and Together has limited supported for it too. Subsequent PRs will add these changes. We would like all our inference providers to provide structured output for llama models since it is an extremely important and highly sought-after need by the developers. --- llama_stack/apis/inference/client.py | 2 + llama_stack/apis/inference/inference.py | 28 +++++++ llama_stack/distribution/routers/routers.py | 4 + .../adapters/inference/bedrock/bedrock.py | 2 + .../inference/databricks/databricks.py | 2 + .../adapters/inference/fireworks/fireworks.py | 17 ++++ .../adapters/inference/ollama/ollama.py | 2 + .../providers/adapters/inference/tgi/tgi.py | 14 ++++ .../adapters/inference/together/together.py | 2 + .../providers/adapters/inference/vllm/vllm.py | 2 + .../agents/tests/test_chat_agent.py | 1 + .../meta_reference/inference/generation.py | 84 ++++++++++++++++++- .../meta_reference/inference/inference.py | 4 + llama_stack/providers/registry/inference.py | 38 ++++----- .../tests/inference/test_inference.py | 59 +++++++++++++ .../utils/inference/prompt_adapter.py | 21 ++++- 16 files changed, 257 insertions(+), 25 deletions(-) diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index 90636fa36..7359c6057 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -53,6 +53,7 @@ class InferenceClient(Inference): tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -63,6 +64,7 @@ class InferenceClient(Inference): tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, + response_format=response_format, stream=stream, logprobs=logprobs, ) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 5895e528e..4ee01acae 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -74,11 +74,35 @@ class ChatCompletionResponseEvent(BaseModel): stop_reason: Optional[StopReason] = None +class ResponseFormatType(Enum): + json_schema = "json_schema" + grammar = "grammar" + + +class JsonResponseFormat(BaseModel): + type: Literal[ResponseFormatType.json_schema.value] = ( + ResponseFormatType.json_schema.value + ) + schema: Dict[str, Any] + + +class GrammarResponseFormat(BaseModel): + type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value + bnf: Dict[str, Any] + + +ResponseFormat = Annotated[ + Union[JsonResponseFormat, GrammarResponseFormat], + Field(discriminator="type"), +] + + @json_schema_type class CompletionRequest(BaseModel): model: str content: InterleavedTextMedia sampling_params: Optional[SamplingParams] = SamplingParams() + response_format: Optional[ResponseFormat] = None stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None @@ -107,6 +131,7 @@ class BatchCompletionRequest(BaseModel): model: str content_batch: List[InterleavedTextMedia] sampling_params: Optional[SamplingParams] = SamplingParams() + response_format: Optional[ResponseFormat] = None logprobs: Optional[LogProbConfig] = None @@ -129,6 +154,7 @@ class ChatCompletionRequest(BaseModel): tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) + response_format: Optional[ResponseFormat] = None stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None @@ -188,6 +214,7 @@ class Inference(Protocol): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... @@ -204,6 +231,7 @@ class Inference(Protocol): tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index a78e808d0..26a0988ed 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -75,6 +75,7 @@ class InferenceRouter(Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, @@ -88,6 +89,7 @@ class InferenceRouter(Inference): tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, + response_format=response_format, stream=stream, logprobs=logprobs, ) @@ -102,6 +104,7 @@ class InferenceRouter(Inference): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -110,6 +113,7 @@ class InferenceRouter(Inference): model=model, content=content, sampling_params=sampling_params, + response_format=response_format, stream=stream, logprobs=logprobs, ) diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 8440ecc20..3800c0496 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -52,6 +52,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: @@ -288,6 +289,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, # zero-shot tool definitions as input to the model tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py index 9f50ad227..4752e3fe4 100644 --- a/llama_stack/providers/adapters/inference/databricks/databricks.py +++ b/llama_stack/providers/adapters/inference/databricks/databricks.py @@ -53,6 +53,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -63,6 +64,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index 537f3a6b4..441f32166 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -56,6 +56,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -69,6 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -79,6 +81,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, + response_format=response_format, stream=stream, logprobs=logprobs, ) @@ -115,6 +118,20 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): options = get_sampling_options(request) options.setdefault("max_tokens", 512) + + if fmt := request.response_format: + if fmt.type == ResponseFormatType.json_schema.value: + options["response_format"] = { + "type": "json_object", + "schema": fmt.schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + options["response_format"] = { + "type": "grammar", + "grammar": fmt.bnf, + } + else: + raise ValueError(f"Unknown response format {fmt.type}") return { "model": self.map_to_provider_model(request.model), "prompt": prompt, diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index b19d54182..d4fe75cfa 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -93,6 +93,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -160,6 +161,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index 3c610099c..f19181320 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -71,6 +71,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -84,6 +85,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -94,6 +96,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, + response_format=response_format, stream=stream, logprobs=logprobs, ) @@ -148,6 +151,17 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): self.max_tokens - input_tokens - 1, ) options = get_sampling_options(request) + if fmt := request.response_format: + if fmt.type == ResponseFormatType.json_schema.value: + options["grammar"] = { + "type": "json", + "value": fmt.schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + raise ValueError("Grammar response format not supported yet") + else: + raise ValueError(f"Unexpected response format: {fmt.type}") + return dict( prompt=prompt, stream=request.stream, diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 8c73d75ec..f88e4c4c2 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -59,6 +59,7 @@ class TogetherInferenceAdapter( model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -69,6 +70,7 @@ class TogetherInferenceAdapter( model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index a5934928a..dacf646b0 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -80,6 +80,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: @@ -90,6 +91,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py index 46423814b..782e0ca7d 100644 --- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py +++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py @@ -26,6 +26,7 @@ class MockInferenceAPI: model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index 9ca128176..b424a9347 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -8,11 +8,12 @@ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. import json +import math import os import sys import time from pathlib import Path -from typing import Generator, List, Optional, Union +from typing import Generator, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -34,6 +35,9 @@ from pydantic import BaseModel from termcolor import cprint from llama_stack.apis.inference import * # noqa: F403 + +from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData + from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, @@ -67,7 +71,7 @@ class Llama: def build( config: Union[ MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig - ] + ], ): """ Build a Llama instance by initializing and loading a model checkpoint. @@ -171,6 +175,7 @@ class Llama: echo: bool = False, include_stop_token: bool = False, print_input_tokens: bool = False, + logits_processor: Optional["LogitsProcessor"] = None, ) -> Generator: params = self.model.params @@ -246,6 +251,9 @@ class Llama: else: logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + if logits_processor is not None: + logits = logits_processor.process_logits(tokens[:, :cur_pos], logits) + if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) next_token = sample_top_p(probs, top_p) @@ -317,7 +325,11 @@ class Llama: top_p=sampling_params.top_p, logprobs=bool(request.logprobs), include_stop_token=True, - echo=False, + logits_processor=get_logits_processor( + self.tokenizer, + self.args.vocab_size, + request.response_format, + ), ) def chat_completion( @@ -345,6 +357,11 @@ class Llama: top_p=sampling_params.top_p, logprobs=bool(request.logprobs), include_stop_token=True, + logits_processor=get_logits_processor( + self.tokenizer, + self.args.vocab_size, + request.response_format, + ), ) @@ -371,3 +388,64 @@ def sample_top_p(probs, p): next_token = torch.multinomial(probs_sort, num_samples=1) next_token = torch.gather(probs_idx, -1, next_token) return next_token + + +class LogitsProcessor: + def __init__(self, token_enforcer: TokenEnforcer): + self.token_enforcer = token_enforcer + self.mask: Optional[torch.Tensor] = None + + def process_logits( + self, tokens: torch.Tensor, scores: torch.Tensor + ) -> torch.Tensor: + token_sequence = tokens[0, :].tolist() + allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence) + + if self.mask is not None: + self.mask.fill_(-math.inf) + else: + self.mask = torch.full_like(scores, -math.inf) + + self.mask[:, :, allowed_tokens] = 0 + scores = scores + self.mask + return scores + + +def get_logits_processor( + tokenizer: Tokenizer, + vocab_size: int, + response_format: Optional[ResponseFormat], +) -> Optional["LogitsProcessor"]: + if response_format is None: + return None + + if response_format.type != ResponseFormatType.json_schema.value: + raise ValueError(f"Unsupported response format type {response_format.type}") + + parser = JsonSchemaParser(response_format.schema) + data = TokenEnforcerTokenizerData( + _build_regular_tokens_list(tokenizer, vocab_size), + tokenizer.decode, + tokenizer.stop_tokens, + ) + token_enforcer = TokenEnforcer(data, parser) + return LogitsProcessor(token_enforcer) + + +def _build_regular_tokens_list( + tokenizer: Tokenizer, vocab_size: int +) -> List[Tuple[int, str, bool]]: + token_0 = tokenizer.encode("0", bos=False, eos=False)[-1] + regular_tokens = [] + + special_token_ids = set(tokenizer.special_tokens.values()) + for token_idx in range(vocab_size): + if token_idx in special_token_ids: + continue + + # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word. + decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:] + decoded_regular = tokenizer.decode([token_idx]) + is_word_start_token = len(decoded_after_0) > len(decoded_regular) + regular_tokens.append((token_idx, decoded_after_0, is_word_start_token)) + return regular_tokens diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index 34053343e..5588be6c0 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -71,6 +71,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): model: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: @@ -81,6 +82,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): model=model, content=content, sampling_params=sampling_params, + response_format=response_format, stream=stream, logprobs=logprobs, ) @@ -186,6 +188,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, @@ -203,6 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, + response_format=response_format, stream=stream, logprobs=logprobs, ) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 5a09b6af5..6f8bc2c6e 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -9,36 +9,36 @@ from typing import List from llama_stack.distribution.datatypes import * # noqa: F403 +META_REFERENCE_DEPS = [ + "accelerate", + "blobfile", + "fairscale", + "torch", + "torchvision", + "transformers", + "zmq", + "lm-format-enforcer", +] + + def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.inference, provider_type="meta-reference", - pip_packages=[ - "accelerate", - "blobfile", - "fairscale", - "torch", - "torchvision", - "transformers", - "zmq", - ], + pip_packages=META_REFERENCE_DEPS, module="llama_stack.providers.impls.meta_reference.inference", config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig", ), InlineProviderSpec( api=Api.inference, provider_type="meta-reference-quantized", - pip_packages=[ - "accelerate", - "blobfile", - "fairscale", - "fbgemm-gpu==0.8.0", - "torch", - "torchvision", - "transformers", - "zmq", - ], + pip_packages=( + META_REFERENCE_DEPS + + [ + "fbgemm-gpu==0.8.0", + ] + ), module="llama_stack.providers.impls.meta_reference.inference", config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig", ), diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index afec9a837..e89f672b1 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -10,6 +10,8 @@ import os import pytest import pytest_asyncio +from pydantic import BaseModel, ValidationError + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 @@ -183,6 +185,63 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages assert len(response.completion_message.content) > 0 +@pytest.mark.asyncio +async def test_structured_output(inference_settings): + inference_impl = inference_settings["impl"] + params = inference_settings["common_params"] + + provider = inference_impl.routing_table.get_provider_impl(params["model"]) + if provider.__provider_spec__.provider_type not in ( + "meta-reference", + "remote::fireworks", + "remote::tgi", + ): + pytest.skip("Other inference providers don't support structured output yet") + + class AnswerFormat(BaseModel): + first_name: str + last_name: str + year_of_birth: int + num_seasons_in_nba: int + + response = await inference_impl.chat_completion( + messages=[ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="Please give me information about Michael Jordan."), + ], + stream=False, + response_format=JsonResponseFormat( + schema=AnswerFormat.model_json_schema(), + ), + **inference_settings["common_params"], + ) + + assert isinstance(response, ChatCompletionResponse) + assert response.completion_message.role == "assistant" + assert isinstance(response.completion_message.content, str) + + answer = AnswerFormat.parse_raw(response.completion_message.content) + assert answer.first_name == "Michael" + assert answer.last_name == "Jordan" + assert answer.year_of_birth == 1963 + assert answer.num_seasons_in_nba == 15 + + response = await inference_impl.chat_completion( + messages=[ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="Please give me information about Michael Jordan."), + ], + stream=False, + **inference_settings["common_params"], + ) + + assert isinstance(response, ChatCompletionResponse) + assert isinstance(response.completion_message.content, str) + + with pytest.raises(ValidationError): + AnswerFormat.parse_raw(response.completion_message.content) + + @pytest.mark.asyncio async def test_chat_completion_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 9d695698f..48f1df02f 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json from typing import Tuple from llama_models.llama3.api.chat_format import ChatFormat @@ -70,11 +71,25 @@ def chat_completion_request_to_messages( and is_multimodal(model.core_model_id) ): # llama3.1 and llama3.2 multimodal models follow the same tool prompt format - return augment_messages_for_tools_llama_3_1(request) + messages = augment_messages_for_tools_llama_3_1(request) elif model.model_family == ModelFamily.llama3_2: - return augment_messages_for_tools_llama_3_2(request) + messages = augment_messages_for_tools_llama_3_2(request) else: - return request.messages + messages = request.messages + + if fmt := request.response_format: + if fmt.type == ResponseFormatType.json_schema.value: + messages.append( + UserMessage( + content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}" + ) + ) + elif fmt.type == ResponseFormatType.grammar.value: + raise NotImplementedError("Grammar response format not supported yet") + else: + raise ValueError(f"Unknown response format {fmt.type}") + + return messages def augment_messages_for_tools_llama_3_1(