llama-stack/llama_stack/providers/adapters/inference/fireworks/fireworks.py
Ashwin Bharambe c06718fbd5
Add support for Structured Output / Guided decoding (#281)
Added support for structured output in the API and added a reference implementation for meta-reference.

A few notes:

* Two formats are specified in the API: Json schema and EBNF based grammar
* Implementation only supports Json for now
We use lm-format-enforcer to provide the implementation right now, but this may change, especially because BNF grammars aren't supported by that library.
Fireworks has support for structured output and Together has limited support for it too. Subsequent PRs will add these changes. We would like all our inference providers to provide structured output for llama models since it is an extremely important and highly sought-after need by the developers.
2024-10-22 12:53:34 -07:00

147 lines
5.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import FireworksImplConfig
# Lookup table from llama-stack model names (keys) to the model identifiers
# used on the Fireworks side (values). FireworksInferenceAdapter hands this
# table to ModelRegistryHelper as `stack_to_provider_models_map`, and the
# translated value is what gets sent as the `model` request parameter.
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
}
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
    """Inference adapter that serves chat completions through the Fireworks client.

    Chat requests are rendered into a single prompt string with the Llama 3
    chat format and sent to the Fireworks completion endpoint. Raw text
    completion and embeddings are not implemented by this adapter.
    """

    def __init__(self, config: FireworksImplConfig) -> None:
        """Register the supported model table and set up the prompt formatter.

        Args:
            config: Adapter configuration; its ``api_key`` is read when a
                Fireworks client is created for a request.
        """
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
        )
        self.config = config
        tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(tokenizer)

    async def initialize(self) -> None:
        """No start-up work is performed by this adapter."""
        return None

    async def shutdown(self) -> None:
        """No tear-down work is performed by this adapter."""
        return None

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Raw completion entry point; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion against Fireworks.

        The arguments are bundled into a ``ChatCompletionRequest`` (a missing
        ``tools`` value becomes an empty list) and a fresh Fireworks client is
        created for the call.

        Returns:
            When ``stream`` is truthy, the async generator produced by
            ``_stream_chat_completion``; otherwise the awaited result of
            ``_nonstream_chat_completion``.
        """
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools if tools else [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        fireworks_client = Fireworks(api_key=self.config.api_key)

        if not stream:
            return await self._nonstream_chat_completion(request, fireworks_client)
        # The streaming path hands back the generator itself, un-awaited.
        return self._stream_chat_completion(request, fireworks_client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: Fireworks
    ) -> ChatCompletionResponse:
        """Issue one completion call and convert the reply into a chat response."""
        raw_response = await client.completion.acreate(**self._get_params(request))
        return process_chat_completion_response(raw_response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: Fireworks
    ) -> AsyncGenerator:
        """Yield processed chat-completion chunks from the Fireworks stream."""
        # Not awaited here: the call's result is passed straight to the
        # stream processor, which is iterated below.
        raw_stream = client.completion.acreate(**self._get_params(request))
        processed = process_chat_completion_stream_response(
            raw_stream, self.formatter
        )
        async for event in processed:
            yield event

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        """Build the keyword arguments for the Fireworks completion call.

        Returns:
            A dict with ``model``, ``prompt`` and ``stream`` followed by the
            sampling options (which default ``max_tokens`` to 512) and, when
            the request asks for one, a ``response_format`` entry.

        Raises:
            ValueError: If the request's response format type is neither the
                JSON-schema nor the grammar type.
        """
        bos_token = "<|begin_of_text|>"
        prompt = chat_completion_request_to_prompt(request, self.formatter)
        # Fireworks always prepends BOS on its side, so drop the one we rendered.
        if prompt.startswith(bos_token):
            prompt = prompt[len(bos_token) :]

        sampling = get_sampling_options(request)
        sampling.setdefault("max_tokens", 512)

        requested_format = request.response_format
        if requested_format:
            format_type = requested_format.type
            if format_type == ResponseFormatType.json_schema.value:
                sampling["response_format"] = {
                    "type": "json_object",
                    "schema": requested_format.schema,
                }
            elif format_type == ResponseFormatType.grammar.value:
                sampling["response_format"] = {
                    "type": "grammar",
                    "grammar": requested_format.bnf,
                }
            else:
                raise ValueError(f"Unknown response format {format_type}")

        params = {
            "model": self.map_to_provider_model(request.model),
            "prompt": prompt,
            "stream": request.stream,
        }
        # Sampling options go last so they take precedence on any key clash.
        params.update(sampling)
        return params

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Embeddings entry point; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()