llama-stack/llama_stack/providers/remote/inference/fireworks/fireworks.py
ehhuang c9ab72fa82
Support sys_prompt behavior in inference (#937)
# What does this PR do?

The current default system prompt for llama3.2 tends to overindex on
tool calling and doesn't work well when the prompt does not require tool
calling.

This PR adds an option to override the default system prompt, and
organizes tool-related configs into a new config object.

- [ ] Addresses issue (#issue)


## Test Plan

python -m unittest
llama_stack.providers.tests.inference.test_prompt_adapter


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/937).
* #938
* __->__ #937
2025-02-03 23:35:16 -08:00

299 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
from .config import FireworksImplConfig
MODEL_ALIASES = [
build_model_alias(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_api_key(self) -> str:
if self.config.api_key is not None:
return self.config.api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
)
return provider_data.fireworks_api_key
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
return Fireworks(api_key=fireworks_api_key)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._get_client().completion.acreate(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# Wrapper for async generator similar
async def _to_async_generator():
stream = self._get_client().completion.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
def _build_options(
self,
sampling_params: Optional[SamplingParams],
fmt: ResponseFormat,
logprobs: Optional[LogProbConfig],
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
options["logprobs"] = logprobs.top_k
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
raise ValueError("Required range: 0 < top_k < 5")
return options
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
if "messages" in params:
stream = self._get_client().chat.completions.acreate(**params)
else:
stream = self._get_client().completion.acreate(**params)
async for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
else:
assert not media_present, "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
# Fireworks always prepends with BOS
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
kwargs = {}
if model.metadata.get("embedding_dimensions"):
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert all(not content_has_media(content) for content in contents), (
"Fireworks does not support media for embeddings"
)
response = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)