diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py
index 827bc620f..275ce99e7 100644
--- a/llama_stack/providers/remote/inference/fireworks/config.py
+++ b/llama_stack/providers/remote/inference/fireworks/config.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Optional
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
@@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel):
         default="https://api.fireworks.ai/inference",
         description="The URL for the Fireworks server",
     )
-    api_key: str = Field(
-        default="",
+    api_key: Optional[str] = Field(
+        default=None,
         description="The Fireworks.ai API Key",
     )
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 0070756d8..57e851c5b 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -9,12 +9,11 @@ from typing import AsyncGenerator
 
 from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.inference import *  # noqa: F403
-
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import FireworksImplConfig
 
-
 FIREWORKS_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
     "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
@@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = {
     "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
     "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
     "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
+    "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
 }
 
 
-class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
+class FireworksInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
@@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
-        return
+        pass
 
     async def shutdown(self) -> None:
         pass
 
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = None
+        if self.config.api_key is not None:
+            fireworks_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.fireworks_api_key:
+                raise ValueError(
+                    'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key> }'
+                )
+            fireworks_api_key = provider_data.fireworks_api_key
+        return Fireworks(api_key=fireworks_api_key)
+
     async def completion(
         self,
         model: str,
@@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        client = Fireworks(api_key=self.config.api_key)
         if stream:
-            return self._stream_completion(request, client)
+            return self._stream_completion(request)
         else:
-            return await self._nonstream_completion(request, client)
+            return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-        self, request: CompletionRequest, client: Fireworks
+        self, request: CompletionRequest
     ) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await client.completion.acreate(**params)
+        r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r, self.formatter)
 
-    async def _stream_completion(
-        self, request: CompletionRequest, client: Fireworks
-    ) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        stream = client.completion.acreate(**params)
+        # Wrap the synchronous streaming call in an async generator
+        async def _to_async_generator():
+            stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_completion_stream_response(stream, self.formatter):
             yield chunk
 
+    def _build_options(
+        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+        options.setdefault("max_tokens", 512)
+
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.json_schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                options["response_format"] = {
+                    "type": "grammar",
+                    "grammar": fmt.bnf,
+                }
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
+        return options
+
     async def chat_completion(
         self,
         model: str,
@@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )
 
-        client = Fireworks(api_key=self.config.api_key)
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request)
         else:
-            return await self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: Fireworks
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await client.chat.completions.acreate(**params)
+            r = await self._get_client().chat.completions.acreate(**params)
         else:
-            r = await client.completion.acreate(**params)
+            r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: Fireworks
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        if "messages" in params:
-            stream = client.chat.completions.acreate(**params)
-        else:
-            stream = client.completion.acreate(**params)
+        async def _to_async_generator():
+            if "messages" in params:
+                stream = await self._get_client().chat.completions.acreate(**params)
+            else:
+                stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
 
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):
@@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
                 input_dict["prompt"] = chat_completion_request_to_prompt(
                     request, self.formatter
                 )
-        elif isinstance(request, CompletionRequest):
+        else:
             assert (
                 not media_present
             ), "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
-        options = get_sampling_options(request.sampling_params)
-        options.setdefault("max_tokens", 512)
-
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.json_schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                options["response_format"] = {
-                    "type": "grammar",
-                    "grammar": fmt.bnf,
-                }
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
-
         return {
             "model": self.map_to_provider_model(request.model),
             **input_dict,
             "stream": request.stream,
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         }
 
     async def embeddings(
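
Note on the streaming change above: the `fireworks.client` stream is consumed with a plain `for` loop in the diff, i.e. it behaves as a synchronous iterator, so `_to_async_generator` bridges it into something `async for` can consume. Below is a minimal, self-contained sketch of the same pattern; a plain list stands in for the Fireworks chunk iterator, and none of these names are from the diff itself.

    import asyncio
    from typing import AsyncGenerator, Iterable

    def sync_chunks() -> Iterable[str]:
        # Stand-in for the synchronous stream returned by completion.create().
        yield from ["Hello", ", ", "world", "\n"]

    async def to_async_generator(chunks: Iterable[str]) -> AsyncGenerator[str, None]:
        # Same trick as the diff's _to_async_generator: iterating a sync
        # stream inside an async generator lets callers use `async for`.
        # Caveat: each blocking next() still runs on the event loop thread.
        for chunk in chunks:
            yield chunk

    async def main() -> None:
        async for chunk in to_async_generator(sync_chunks()):
            print(chunk, end="")

    asyncio.run(main())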
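Note for reviewers exercising the new per-request credential path: a hedged sketch of a request that supplies the key via the header instead of the run config. Only the header name `X-LlamaStack-ProviderData` and the `fireworks_api_key` field come from `_get_client()` above; the server URL, route, and payload shape are assumptions about a local llama-stack deployment, so adjust them to match yours.

    import json

    import requests

    # Assumed local llama-stack server and route; adjust to your deployment.
    URL = "http://localhost:5000/inference/chat_completion"

    headers = {
        # Header and key name taken from the ValueError in _get_client().
        "X-LlamaStack-ProviderData": json.dumps({"fireworks_api_key": "<your api key>"}),
    }

    payload = {
        "model": "Llama3.1-8B-Instruct",  # any key in FIREWORKS_SUPPORTED_MODELS
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    }

    response = requests.post(URL, headers=headers, json=payload)
    print(response.json())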