pre-commit fixes

Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from fireworks.client import Fireworks
-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -55,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import FireworksImplConfig
 from .models import MODEL_ENTRIES
 
+logger = get_logger(name=__name__, category="inference")
+
 
 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
@@ -68,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         pass
 
     def _get_api_key(self) -> str:
-        if self.config.api_key is not None:
-            return self.config.api_key.get_secret_value()
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        if config_api_key:
+            return config_api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
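The reworked _get_api_key reads in two steps: prefer the statically configured key, otherwise fall back to the per-request provider data. A minimal sketch of that lookup order, assuming stand-in names (FakeProviderData, resolve_api_key, and the error message are illustrative, not the adapter's real helpers):

from typing import Optional

class FakeProviderData:
    # Stand-in for the object returned by get_request_provider_data()
    def __init__(self, fireworks_api_key: Optional[str] = None):
        self.fireworks_api_key = fireworks_api_key

def resolve_api_key(config_api_key: Optional[str], provider_data: Optional[FakeProviderData]) -> str:
    if config_api_key:
        # a key in the static config always wins
        return config_api_key
    if provider_data is None or not provider_data.fireworks_api_key:
        raise ValueError("Missing Fireworks API key")  # hypothetical message
    return provider_data.fireworks_api_key

assert resolve_api_key("cfg-key", None) == "cfg-key"
assert resolve_api_key(None, FakeProviderData("req-key")) == "req-key"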
@@ -86,11 +89,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
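This default-argument change (repeated for the chat-completion path below) fixes a classic Python pitfall: a default such as SamplingParams() is evaluated once, at function definition time, so every call shares a single instance and any in-place mutation leaks across requests. A minimal sketch of the pitfall and the fix, using a stand-in dataclass rather than the real SamplingParams:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FakeSamplingParams:
    stop: List[str] = field(default_factory=list)

def shared_default(params: FakeSamplingParams = FakeSamplingParams()) -> FakeSamplingParams:
    params.stop.append("</s>")  # mutates the one instance built at def time
    return params

def fresh_default(params: Optional[FakeSamplingParams] = None) -> FakeSamplingParams:
    if params is None:
        params = FakeSamplingParams()  # a new instance on every call
    params.stop.append("</s>")
    return params

assert len(shared_default().stop) == 1
assert len(shared_default().stop) == 2  # state leaked between calls
assert len(fresh_default().stop) == 1
assert len(fresh_default().stop) == 1  # no leak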
@@ -157,7 +162,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -166,6 +171,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
@@ -233,7 +240,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
-        logcat.debug("inference", f"params to fireworks: {params}")
+        logger.debug(f"params to fireworks: {params}")
         return params
 
     async def embeddings(
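The logging hunks hang together: the import hunk drops logcat, the module-level logger is created once with get_logger(name=__name__, category="inference"), and the call site loses its hand-written "inference" tag. A rough stdlib approximation of the pattern; the shim below is an assumption, since the diff shows only get_logger's call signature, not its implementation:

import logging

def get_logger(name: str, category: str) -> logging.Logger:
    # hypothetical shim: namespace loggers by category so each
    # subsystem can be filtered or leveled independently
    return logging.getLogger(f"{category}.{name}")

logging.basicConfig(level=logging.DEBUG)
logger = get_logger(name=__name__, category="inference")
logger.debug("params to fireworks: %s", {"stream": False})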

@@ -24,10 +24,6 @@ MODEL_ENTRIES = [
         "accounts/fireworks/models/llama-v3p1-405b-instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-    build_hf_repo_model_entry(
-        "accounts/fireworks/models/llama-v3p2-1b-instruct",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
     build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p2-3b-instruct",
         CoreModelId.llama3_2_3b_instruct.value,