Add support for fireworks

This commit is contained in:
Ashwin Bharambe 2024-10-21 22:42:29 -07:00
parent cd84dee3e9
commit fe20a69f24
2 changed files with 37 additions and 4 deletions

View file

@ -67,10 +67,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -81,6 +81,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -117,6 +118,20 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
options = get_sampling_options(request)
options.setdefault("max_tokens", 512)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,