From 40ba22f4c8f86c7d562b796a99e5023356b9e7ad Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 21 Oct 2024 19:14:52 -0700
Subject: [PATCH] add response format to signature

---
 llama_stack/apis/inference/client.py          |  2 +
 llama_stack/apis/inference/inference.py       |  9 ++-
 llama_stack/distribution/routers/routers.py   |  2 +
 .../adapters/inference/bedrock/bedrock.py     |  2 +
 .../inference/databricks/databricks.py        |  2 +
 .../adapters/inference/fireworks/fireworks.py |  2 +
 .../adapters/inference/ollama/ollama.py       |  2 +
 .../providers/adapters/inference/tgi/tgi.py   |  2 +
 .../adapters/inference/together/together.py   |  2 +
 .../providers/adapters/inference/vllm/vllm.py |  2 +
 .../agents/tests/test_chat_agent.py           |  1 +
 .../meta_reference/inference/generation.py    | 62 ++++++++++++-------
 .../meta_reference/inference/inference.py     |  4 ++
 .../tests/inference/test_inference.py         | 11 +++-
 .../utils/inference/prompt_adapter.py         | 20 +++++-
 15 files changed, 93 insertions(+), 32 deletions(-)

diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 90636fa36..7359c6057 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -53,6 +53,7 @@ class InferenceClient(Inference):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -63,6 +64,7 @@ class InferenceClient(Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 8dc547b2d..f256901ee 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -74,13 +74,18 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None
 
 
+class ResponseFormatType(Enum):
+    json = "json"
+    grammar = "grammar"
+
+
 class JsonResponseFormat(BaseModel):
-    type: Literal["json"] = "json"
+    type: Literal[ResponseFormatType.json.value] = ResponseFormatType.json.value
     schema: Dict[str, Any]
 
 
 class GrammarResponseFormat(BaseModel):
-    type: Literal["grammar"] = "grammar"
+    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
     bnf: Dict[str, Any]
 
 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index a78e808d0..b33c5ec36 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -75,6 +75,7 @@ class InferenceRouter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@@ -102,6 +103,7 @@ class InferenceRouter(Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index 8440ecc20..3800c0496 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -52,6 +52,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -288,6 +289,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index 9f50ad227..4752e3fe4 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -53,6 +53,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -63,6 +64,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 537f3a6b4..1f598b277 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -56,6 +56,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -66,6 +67,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index b19d54182..d4fe75cfa 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -93,6 +93,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -160,6 +161,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 3c610099c..85bbb34b2 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -71,6 +71,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -81,6 +82,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 8c73d75ec..f88e4c4c2 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -59,6 +59,7 @@ class TogetherInferenceAdapter(
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -69,6 +70,7 @@ class TogetherInferenceAdapter(
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
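Note on the ResponseFormat type threaded through these signatures: the patch only shows the two Pydantic models added in llama_stack/apis/inference/inference.py (JsonResponseFormat, GrammarResponseFormat), not how they are unioned. A minimal sketch of the kind of discriminated union that would back the annotation, offered as an assumption rather than a copy of the file:

    from typing import Annotated, Union

    from pydantic import Field

    from llama_stack.apis.inference.inference import (
        GrammarResponseFormat,
        JsonResponseFormat,
    )

    # Assumed wiring (not shown in this diff): a tagged union keyed on the
    # `type` Literal carried by each response-format model.
    ResponseFormat = Annotated[
        Union[JsonResponseFormat, GrammarResponseFormat],
        Field(discriminator="type"),
    ]

With a union shaped like this, providers can branch on response_format.type without isinstance checks, which is how the meta-reference implementation further down dispatches.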
diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index a5934928a..dacf646b0 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -80,6 +80,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -90,6 +91,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
index 46423814b..782e0ca7d 100644
--- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
+++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
@@ -26,6 +26,7 @@ class MockInferenceAPI:
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index cdf4ec79d..fc1c809ad 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -80,7 +80,7 @@ class Llama:
     def build(
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-        ]
+        ],
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -184,17 +184,11 @@ class Llama:
         echo: bool = False,
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
+        logits_processor: Optional["LogitsProcessor"] = None,
     ) -> Generator:
-        parser = JsonSchemaParser(AnswerFormat.schema())
-        tokenizer_data = build_token_enforcer_tokenizer_data(
-            self.tokenizer, self.args.vocab_size
-        )
-        token_enforcer = TokenEnforcer(tokenizer_data, parser)
-        logits_processor = LogitsProcessor(token_enforcer)
-
         params = self.model.params
 
-        if print_input_tokens or True:
+        if print_input_tokens:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
@@ -266,10 +260,10 @@
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
-            # print(f"{logits=}")
-            input_ids = tokens[0, :cur_pos].tolist()
-            # logits = logits_processor.process_logits(input_ids, logits)
-            # print(f"{logits=}")
+            if logits_processor is not None:
+                logits = logits_processor.process_logits(
+                    tokens[0, :cur_pos].tolist(), logits
+                )
 
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -342,7 +336,11 @@
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
-            echo=False,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )
 
     def chat_completion(
@@ -370,6 +368,11 @@
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )
 
 
@@ -398,6 +401,27 @@ def sample_top_p(probs, p):
     return next_token
 
 
+def get_logits_processor(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+    response_format: Optional[ResponseFormat],
+) -> Optional[LogitsProcessor]:
+    if response_format is None:
+        return None
+
+    if response_format.type != ResponseFormatType.json.value:
+        raise ValueError(f"Unsupported response format type {response_format.type}")
+
+    parser = JsonSchemaParser(response_format.schema)
+    data = TokenEnforcerTokenizerData(
+        _build_regular_tokens_list(tokenizer, vocab_size),
+        tokenizer.decode,
+        tokenizer.stop_tokens,
+    )
+    token_enforcer = TokenEnforcer(data, parser)
+    return LogitsProcessor(token_enforcer)
+
+
 class LogitsProcessor:
     def __init__(self, token_enforcer: TokenEnforcer):
         self.token_enforcer = token_enforcer
@@ -437,13 +461,3 @@ def _build_regular_tokens_list(
         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
     return regular_tokens
-
-
-def build_token_enforcer_tokenizer_data(
-    tokenizer: Tokenizer,
-    vocab_size: int,
-) -> TokenEnforcerTokenizerData:
-    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
-    return TokenEnforcerTokenizerData(
-        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
-    )
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 34053343e..5588be6c0 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -71,6 +71,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@@ -81,6 +82,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -186,6 +188,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@@ -203,6 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
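The meta-reference path above hands request.response_format to get_logits_processor, which builds a TokenEnforcer from lm-format-enforcer and wraps it in the pre-existing LogitsProcessor. The body of process_logits is not part of this diff; below is a rough sketch of the masking such a processor typically applies, under the assumption that TokenEnforcer.get_allowed_tokens is used as in lm-format-enforcer's integration examples:

    import math
    from typing import List

    import torch
    from lmformatenforcer import TokenEnforcer


    class LogitsProcessor:
        def __init__(self, token_enforcer: TokenEnforcer):
            self.token_enforcer = token_enforcer

        def process_logits(
            self, input_ids: List[int], logits: torch.Tensor
        ) -> torch.Tensor:
            # Ask the enforcer which token ids may legally come next, given the
            # tokens emitted so far and the JSON schema being enforced.
            allowed = self.token_enforcer.get_allowed_tokens(input_ids)
            # Push every other token to -inf so sampling/argmax stays in-schema.
            mask = torch.full_like(logits, -math.inf)
            mask[..., allowed] = 0
            return logits + mask

Because generate() only invokes the processor when one was constructed (i.e. when a response_format was supplied), unconstrained requests pay no extra cost.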
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 86e37e39c..18c6327db 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -97,12 +97,12 @@ class AnswerFormat(BaseModel):
 
 
 @pytest.fixture
 def sample_messages():
-    question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
-    question_with_schema = f"{question}{AnswerFormat.schema_json()}"
+    question = "Please give me information about Michael Jordan."
+    # question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
         # UserMessage(content="What's the weather like today?"),
-        UserMessage(content=question_with_schema),
+        UserMessage(content=question),
     ]
 
@@ -183,10 +183,15 @@ async def test_completion(inference_settings):
 
 
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
+    print(AnswerFormat.schema_json())
+    print(AnswerFormat.schema())
     inference_impl = inference_settings["impl"]
     response = await inference_impl.chat_completion(
         messages=sample_messages,
         stream=False,
+        response_format=JsonResponseFormat(
+            schema=AnswerFormat.schema(),
+        ),
         **inference_settings["common_params"],
     )
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 9d695698f..cab2e5169 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -70,11 +70,25 @@ def chat_completion_request_to_messages(
         and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
-        return augment_messages_for_tools_llama_3_1(request)
+        messages = augment_messages_for_tools_llama_3_1(request)
     elif model.model_family == ModelFamily.llama3_2:
-        return augment_messages_for_tools_llama_3_2(request)
+        messages = augment_messages_for_tools_llama_3_2(request)
     else:
-        return request.messages
+        messages = request.messages
+
+    if fmt := request.response_format:
+        if fmt.type == ResponseFormatType.json.value:
+            messages.append(
+                UserMessage(
+                    content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
+                )
+            )
+        elif fmt.type == ResponseFormatType.grammar.value:
+            raise NotImplementedError("Grammar response format not supported yet")
+        else:
+            raise ValueError(f"Unknown response format {fmt.type}")
+
+    return messages
 
 
 def augment_messages_for_tools_llama_3_1(
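Taken together, a caller exercises the new parameter roughly the way the updated test does: build a JSON schema (here via Pydantic), pass it as a JsonResponseFormat, and the provider either constrains decoding (meta-reference, via the logits processor) or appends the schema to the prompt (prompt_adapter-based providers). A hedged end-to-end sketch; the model identifier, the AnswerFormat fields, and the UserMessage import path are illustrative, not taken from this patch:

    from pydantic import BaseModel

    # JsonResponseFormat comes from the API module touched above; the
    # UserMessage import path is assumed here for illustration.
    from llama_stack.apis.inference.inference import JsonResponseFormat, UserMessage


    class AnswerFormat(BaseModel):
        # Illustrative fields; the real test schema may differ.
        first_name: str
        last_name: str
        num_seasons_in_nba: int


    async def ask_structured(client):
        # `client` is assumed to be an InferenceClient (or router) configured elsewhere.
        response = await client.chat_completion(
            model="Llama3.1-8B-Instruct",  # illustrative model id
            messages=[
                UserMessage(content="Please give me information about Michael Jordan.")
            ],
            response_format=JsonResponseFormat(schema=AnswerFormat.schema()),
            stream=False,
        )
        # The completion content is expected to parse as AnswerFormat JSON.
        return response.completion_message.content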