From fe20a69f24515dafd186030f41a22906d876d001 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 21 Oct 2024 22:42:29 -0700
Subject: [PATCH] Add structured output support for fireworks

---
 .../adapters/inference/fireworks/fireworks.py | 17 ++++++++++++++++-
 .../tests/inference/test_inference.py         | 24 ++++++++++++++++++---
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 1f598b277..441f32166 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -67,10 +67,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -81,6 +81,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -117,6 +118,20 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
 
         options = get_sampling_options(request)
         options.setdefault("max_tokens", 512)
+
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                options["response_format"] = {
+                    "type": "grammar",
+                    "grammar": fmt.bnf,
+                }
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
         return {
             "model": self.map_to_provider_model(request.model),
             "prompt": prompt,
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 3e61337eb..c6355b2dd 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -10,7 +10,7 @@
 import os
 
 import pytest
 import pytest_asyncio
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
@@ -191,7 +191,10 @@ async def test_structured_output(inference_settings):
     params = inference_settings["common_params"]
     provider = inference_impl.routing_table.get_provider_impl(params["model"])
 
-    if provider.__provider_id__ != "meta-reference":
+    if provider.__provider_spec__.provider_type not in (
+        "meta-reference",
+        "remote::fireworks",
+    ):
         pytest.skip("Other inference providers don't support structured output yet")
 
     class AnswerFormat(BaseModel):
@@ -207,7 +210,7 @@ async def test_structured_output(inference_settings):
         ],
         stream=False,
         response_format=JsonResponseFormat(
-            schema=AnswerFormat.schema(),
+            schema=AnswerFormat.model_json_schema(),
         ),
         **inference_settings["common_params"],
     )
@@ -222,6 +225,21 @@ async def test_structured_output(inference_settings):
     assert answer.year_of_birth == 1963
     assert answer.num_seasons_in_nba == 15
 
+    response = await inference_impl.chat_completion(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Please give me information about Michael Jordan."),
+        ],
+        stream=False,
+        **inference_settings["common_params"],
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    assert isinstance(response.completion_message.content, str)
+
+    with pytest.raises(ValidationError):
+        AnswerFormat.parse_raw(response.completion_message.content)
+
 
 @pytest.mark.asyncio
 async def test_chat_completion_streaming(inference_settings, sample_messages):
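
Usage sketch (not part of the patch): the test above shows the intended call pattern -- build a JSON schema from a pydantic model, pass it as response_format to chat_completion, and the Fireworks adapter forwards it to the provider as {"type": "json_object", "schema": ...}. The snippet below is only an illustration assembled from the identifiers visible in this diff; the `inference_impl` wiring, the import paths, and the AnswerFormat field set are assumptions, not part of the commit.

    # Illustrative sketch only. Assumes a configured llama-stack inference
    # implementation (`inference_impl`) that routes the given model to the
    # Fireworks adapter patched above, and that JsonResponseFormat is exported
    # from llama_stack.apis.inference as used in the test.
    from pydantic import BaseModel

    from llama_models.llama3.api.datatypes import UserMessage
    from llama_stack.apis.inference import JsonResponseFormat


    class AnswerFormat(BaseModel):
        # Field names taken from the asserts in test_structured_output above.
        year_of_birth: int
        num_seasons_in_nba: int


    async def ask_structured(inference_impl, model: str) -> AnswerFormat:
        response = await inference_impl.chat_completion(
            model=model,
            messages=[
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            # With response_format set, Fireworks constrains decoding to the
            # schema, so the completion content should validate below.
            response_format=JsonResponseFormat(schema=AnswerFormat.model_json_schema()),
        )
        return AnswerFormat.parse_raw(response.completion_message.content)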