diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 90636fa36..7359c6057 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -53,6 +53,7 @@ class InferenceClient(Inference):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -63,6 +64,7 @@ class InferenceClient(Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 5895e528e..4ee01acae 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -74,11 +74,35 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None


+class ResponseFormatType(Enum):
+    json_schema = "json_schema"
+    grammar = "grammar"
+
+
+class JsonResponseFormat(BaseModel):
+    type: Literal[ResponseFormatType.json_schema.value] = (
+        ResponseFormatType.json_schema.value
+    )
+    schema: Dict[str, Any]
+
+
+class GrammarResponseFormat(BaseModel):
+    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
+    bnf: Dict[str, Any]
+
+
+ResponseFormat = Annotated[
+    Union[JsonResponseFormat, GrammarResponseFormat],
+    Field(discriminator="type"),
+]
+
+
 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedTextMedia
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None

@@ -107,6 +131,7 @@ class BatchCompletionRequest(BaseModel):
     model: str
     content_batch: List[InterleavedTextMedia]
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
     logprobs: Optional[LogProbConfig] = None


@@ -129,6 +154,7 @@ class ChatCompletionRequest(BaseModel):
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
+    response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None

@@ -188,6 +214,7 @@ class Inference(Protocol):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@@ -204,6 +231,7 @@ class Inference(Protocol):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
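A minimal sketch (not part of the patch) of how the new discriminated union above is expected to behave, assuming pydantic v2's TypeAdapter; the class and field names simply mirror the definitions in inference.py.

# Sketch only: exercises a ResponseFormat union like the one defined above (assumes pydantic v2).
from typing import Annotated, Any, Dict, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class JsonResponseFormat(BaseModel):
    type: Literal["json_schema"] = "json_schema"
    schema: Dict[str, Any]


class GrammarResponseFormat(BaseModel):
    type: Literal["grammar"] = "grammar"
    bnf: Dict[str, Any]


ResponseFormat = Annotated[
    Union[JsonResponseFormat, GrammarResponseFormat],
    Field(discriminator="type"),
]

# The "type" field selects the concrete model when a request body is parsed.
fmt = TypeAdapter(ResponseFormat).validate_python(
    {"type": "json_schema", "schema": {"type": "object"}}
)
assert isinstance(fmt, JsonResponseFormat)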
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index a78e808d0..26a0988ed 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -75,6 +75,7 @@ class InferenceRouter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@@ -88,6 +89,7 @@ class InferenceRouter(Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -102,6 +104,7 @@ class InferenceRouter(Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -110,6 +113,7 @@ class InferenceRouter(Inference):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index 8440ecc20..3800c0496 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -52,6 +52,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -288,6 +289,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index 9f50ad227..4752e3fe4 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -53,6 +53,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -63,6 +64,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
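The router simply forwards response_format to whichever provider backs the requested model; the Bedrock and Databricks adapters only accept the new parameter here and do not yet consume it. A hypothetical guard (not in this patch) that such an adapter could add to fail fast instead of silently ignoring the request might look like this:

# Illustrative only; names are assumptions, not part of the diff.
from typing import Any, Optional


def _check_response_format(response_format: Optional[Any]) -> None:
    # Reject structured-output requests rather than returning unconstrained text.
    if response_format is not None:
        raise NotImplementedError(
            "This provider does not support response_format yet"
        )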
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 537f3a6b4..441f32166 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -56,6 +56,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -69,6 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -79,6 +81,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -115,6 +118,20 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         options = get_sampling_options(request)
         options.setdefault("max_tokens", 512)
+
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                options["response_format"] = {
+                    "type": "grammar",
+                    "grammar": fmt.bnf,
+                }
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
         return {
             "model": self.map_to_provider_model(request.model),
             "prompt": prompt,
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index b19d54182..d4fe75cfa 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -93,6 +93,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -160,6 +161,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
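A standalone sketch of the Fireworks option mapping introduced above: a json_schema format is translated into Fireworks' "json_object" response_format carrying the schema, and a grammar format is passed through as-is. The helper name and inputs are illustrative, not part of the patch.

# Sketch only; mirrors the mapping logic in the Fireworks adapter above.
from typing import Any, Dict


def response_format_to_fireworks_options(
    fmt_type: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
    if fmt_type == "json_schema":
        return {"response_format": {"type": "json_object", "schema": payload}}
    if fmt_type == "grammar":
        return {"response_format": {"type": "grammar", "grammar": payload}}
    raise ValueError(f"Unknown response format {fmt_type}")


options = response_format_to_fireworks_options(
    "json_schema", {"type": "object", "properties": {"name": {"type": "string"}}}
)
assert options["response_format"]["type"] == "json_object"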
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 3c610099c..f19181320 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -71,6 +71,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -84,6 +85,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -94,6 +96,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -148,6 +151,17 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             self.max_tokens - input_tokens - 1,
         )
         options = get_sampling_options(request)
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["grammar"] = {
+                    "type": "json",
+                    "value": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise ValueError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unexpected response format: {fmt.type}")
+
         return dict(
             prompt=prompt,
             stream=request.stream,
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 8c73d75ec..f88e4c4c2 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -59,6 +59,7 @@ class TogetherInferenceAdapter(
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -69,6 +70,7 @@ class TogetherInferenceAdapter(
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index a5934928a..dacf646b0 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -80,6 +80,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -90,6 +91,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
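A sketch of the TGI mapping added above: a json_schema format becomes a TGI "grammar" option of type "json", while BNF grammar formats are rejected for now. The helper name and inputs below are illustrative, not part of the patch.

# Sketch only; mirrors the mapping logic in the TGI adapter above.
from typing import Any, Dict, Optional


def tgi_grammar_option(
    fmt_type: Optional[str], schema: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    options: Dict[str, Any] = {}
    if fmt_type is None:
        return options
    if fmt_type == "json_schema":
        options["grammar"] = {"type": "json", "value": schema}
    elif fmt_type == "grammar":
        raise ValueError("Grammar response format not supported yet")
    else:
        raise ValueError(f"Unexpected response format: {fmt_type}")
    return options


assert tgi_grammar_option("json_schema", {"type": "object"}) == {
    "grammar": {"type": "json", "value": {"type": "object"}}
}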
diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
index 46423814b..782e0ca7d 100644
--- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
+++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
@@ -26,6 +26,7 @@ class MockInferenceAPI:
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 9ca128176..b424a9347 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -8,11 +8,12 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

 import json
+import math
 import os
 import sys
 import time
 from pathlib import Path
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -34,6 +35,9 @@ from pydantic import BaseModel
 from termcolor import cprint

 from llama_stack.apis.inference import *  # noqa: F403
+
+from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
+
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
@@ -67,7 +71,7 @@ class Llama:
     def build(
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-        ]
+        ],
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -171,6 +175,7 @@ class Llama:
         echo: bool = False,
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
+        logits_processor: Optional["LogitsProcessor"] = None,
     ) -> Generator:
         params = self.model.params

@@ -246,6 +251,9 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

+            if logits_processor is not None:
+                logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
+
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
@@ -317,7 +325,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
-            echo=False,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )

     def chat_completion(
@@ -345,6 +357,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )


@@ -371,3 +388,64 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, tokens: torch.Tensor, scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = tokens[0, :].tolist()
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
+def get_logits_processor(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+    response_format: Optional[ResponseFormat],
+) -> Optional["LogitsProcessor"]:
+    if response_format is None:
+        return None
+
+    if response_format.type != ResponseFormatType.json_schema.value:
+        raise ValueError(f"Unsupported response format type {response_format.type}")
+
+    parser = JsonSchemaParser(response_format.schema)
+    data = TokenEnforcerTokenizerData(
+        _build_regular_tokens_list(tokenizer, vocab_size),
+        tokenizer.decode,
+        tokenizer.stop_tokens,
+    )
+    token_enforcer = TokenEnforcer(data, parser)
+    return LogitsProcessor(token_enforcer)
+
+
+def _build_regular_tokens_list(
+    tokenizer: Tokenizer, vocab_size: int
+) -> List[Tuple[int, str, bool]]:
+    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
+    regular_tokens = []
+
+    special_token_ids = set(tokenizer.special_tokens.values())
+    for token_idx in range(vocab_size):
+        if token_idx in special_token_ids:
+            continue
+
+        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
+        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
+        decoded_regular = tokenizer.decode([token_idx])
+        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
+        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+    return regular_tokens
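A tiny numeric sketch of what LogitsProcessor.process_logits does above: token ids allowed by the enforcer keep their logits, all other ids are pushed to -inf so sampling can never pick them. Shapes and values below are purely illustrative.

# Sketch only; demonstrates the masking arithmetic used by LogitsProcessor.
import math

import torch

logits = torch.tensor([[[1.0, 2.0, 0.5, 3.0]]])  # (bsz=1, seqlen=1, vocab=4)
allowed_tokens = [1, 3]

mask = torch.full_like(logits, -math.inf)
mask[:, :, allowed_tokens] = 0
constrained = logits + mask

probs = torch.softmax(constrained[:, -1], dim=-1)
assert probs[0, 0] == 0 and probs[0, 2] == 0  # disallowed ids get zero probability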
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 34053343e..5588be6c0 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -71,6 +71,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -81,6 +82,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -186,6 +188,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@@ -203,6 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
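With the meta-reference implementation threading the parameter through, the request object now carries response_format alongside sampling params and tools. A minimal sketch, assuming this patch is applied and the listed modules are importable; the model name is illustrative.

# Sketch only: shows the shape of a structured-output chat completion request.
from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.apis.inference import ChatCompletionRequest, JsonResponseFormat

request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",  # assumed model identifier
    messages=[UserMessage(content="Reply with a JSON object containing a 'city' key.")],
    response_format=JsonResponseFormat(
        schema={"type": "object", "properties": {"city": {"type": "string"}}}
    ),
    stream=False,
)
assert request.response_format.type == "json_schema"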
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 5a09b6af5..6f8bc2c6e 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -9,36 +9,36 @@ from typing import List

 from llama_stack.distribution.datatypes import *  # noqa: F403

+META_REFERENCE_DEPS = [
+    "accelerate",
+    "blobfile",
+    "fairscale",
+    "torch",
+    "torchvision",
+    "transformers",
+    "zmq",
+    "lm-format-enforcer",
+]
+
+
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=META_REFERENCE_DEPS,
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
         ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference-quantized",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "fbgemm-gpu==0.8.0",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=(
+                META_REFERENCE_DEPS
+                + [
+                    "fbgemm-gpu==0.8.0",
+                ]
+            ),
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index afec9a837..e89f672b1 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -10,6 +10,8 @@ import os
 import pytest
 import pytest_asyncio

+from pydantic import BaseModel, ValidationError
+
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403

@@ -183,6 +185,63 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
     assert len(response.completion_message.content) > 0


+@pytest.mark.asyncio
+async def test_structured_output(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_spec__.provider_type not in (
+        "meta-reference",
+        "remote::fireworks",
+        "remote::tgi",
+    ):
+        pytest.skip("Other inference providers don't support structured output yet")
+
+    class AnswerFormat(BaseModel):
+        first_name: str
+        last_name: str
+        year_of_birth: int
+        num_seasons_in_nba: int
+
+    response = await inference_impl.chat_completion(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Please give me information about Michael Jordan."),
+        ],
+        stream=False,
+        response_format=JsonResponseFormat(
+            schema=AnswerFormat.model_json_schema(),
+        ),
+        **inference_settings["common_params"],
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.completion_message.role == "assistant"
+    assert isinstance(response.completion_message.content, str)
+
+    answer = AnswerFormat.parse_raw(response.completion_message.content)
+    assert answer.first_name == "Michael"
+    assert answer.last_name == "Jordan"
+    assert answer.year_of_birth == 1963
+    assert answer.num_seasons_in_nba == 15
+
+    response = await inference_impl.chat_completion(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Please give me information about Michael Jordan."),
+        ],
+        stream=False,
+        **inference_settings["common_params"],
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    assert isinstance(response.completion_message.content, str)
+
+    with pytest.raises(ValidationError):
+        AnswerFormat.parse_raw(response.completion_message.content)
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
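For reference, the schema the new test forwards to providers is the plain pydantic JSON schema of AnswerFormat. A small sketch of what that roughly looks like, assuming pydantic v2 (which the test's model_json_schema call implies):

# Sketch only: reproduces the AnswerFormat schema used by test_structured_output.
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


schema = AnswerFormat.model_json_schema()
assert schema["type"] == "object"
assert set(schema["required"]) == {
    "first_name", "last_name", "year_of_birth", "num_seasons_in_nba"
}
print(schema)  # e.g. {"properties": {"first_name": {"title": "First Name", "type": "string"}, ...}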
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 9d695698f..48f1df02f 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 from typing import Tuple

 from llama_models.llama3.api.chat_format import ChatFormat
@@ -70,11 +71,25 @@ def chat_completion_request_to_messages(
         and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
-        return augment_messages_for_tools_llama_3_1(request)
+        messages = augment_messages_for_tools_llama_3_1(request)
     elif model.model_family == ModelFamily.llama3_2:
-        return augment_messages_for_tools_llama_3_2(request)
+        messages = augment_messages_for_tools_llama_3_2(request)
     else:
-        return request.messages
+        messages = request.messages
+
+    if fmt := request.response_format:
+        if fmt.type == ResponseFormatType.json_schema.value:
+            messages.append(
+                UserMessage(
+                    content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
+                )
+            )
+        elif fmt.type == ResponseFormatType.grammar.value:
+            raise NotImplementedError("Grammar response format not supported yet")
+        else:
+            raise ValueError(f"Unknown response format {fmt.type}")
+
+    return messages


 def augment_messages_for_tools_llama_3_1(
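The prompt_adapter change above is a prompt-level nudge that complements the logit-level enforcement in generation.py: for a json_schema format, one extra user turn spelling out the schema is appended before the request reaches the model. A small sketch of the appended content, with an illustrative schema:

# Sketch only: shows the extra user-message content produced for a json_schema format.
import json

schema = {"type": "object", "properties": {"city": {"type": "string"}}}
content = f"Please respond in JSON format with the schema: {json.dumps(schema)}"
print(content)
# Please respond in JSON format with the schema: {"type": "object", "properties": {"city": {"type": "string"}}}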