fix: solve ruff B008 warnings (#1444)

# What does this PR do?

This commit addresses Ruff warning B008 (function call in argument defaults) by
removing direct `SamplingParams()` calls from default values. Pydantic model
fields now use `Field(default_factory=SamplingParams)`, while plain function and
method signatures default to `None` and instantiate `SamplingParams` inside the
body when the argument is `None`.
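
For context, B008 comes from flake8-bugbear: a call placed in an argument default runs once, when the function is defined, so every invocation that relies on the default shares that single object. Below is a minimal, self-contained sketch of the rule and of the two replacements described above; the class and function names are illustrative, not taken from the tree.

```python
from typing import Optional

from pydantic import BaseModel, Field


class SamplingParams(BaseModel):
    """Stand-in for the real SamplingParams model; fields trimmed for the example."""

    temperature: float = 0.0


# What B008 flags: SamplingParams() is evaluated once, at definition time,
# and the same instance is reused by every call that takes the default.
def completion_flagged(sampling_params: Optional[SamplingParams] = SamplingParams()) -> Optional[SamplingParams]:
    return sampling_params


# Fix used for Pydantic fields: default_factory builds a fresh value per model instance.
class CompletionRequestSketch(BaseModel):
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)


# Fix used for plain signatures: default to None and instantiate inside the body.
def completion_fixed(sampling_params: Optional[SamplingParams] = None) -> SamplingParams:
    if sampling_params is None:
        sampling_params = SamplingParams()
    return sampling_params
```

The `None`-sentinel variant is the one applied throughout the provider signatures below, since those are ordinary `async def` methods rather than Pydantic fields.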

Signed-off-by: Sébastien Han <seb@redhat.com>
Authored by Sébastien Han on 2025-03-07 01:48:35 +01:00, committed by GitHub
parent 3a454be9b2
commit 803bf0e029
GPG key ID: B5690EEEBB952194
21 changed files with 93 additions and 43 deletions


@@ -199,7 +199,7 @@ AgentToolGroup = register_schema(
 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = SamplingParams()
+    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
     input_shields: Optional[List[str]] = Field(default_factory=list)
     output_shields: Optional[List[str]] = Field(default_factory=list)


@@ -40,7 +40,7 @@ class BatchInference(Protocol):
         self,
         model: str,
         content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...
@@ -50,7 +50,7 @@ class BatchInference(Protocol):
         self,
         model: str,
         messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = list,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,


@@ -278,7 +278,7 @@ ResponseFormat = register_schema(
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
-    sampling_params: Optional[SamplingParams] = SamplingParams()
+    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
@@ -357,7 +357,7 @@ class ToolConfig(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
-    sampling_params: Optional[SamplingParams] = SamplingParams()
+    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
@@ -444,7 +444,7 @@ class Inference(Protocol):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -467,7 +467,7 @@ class Inference(Protocol):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,


@@ -217,7 +217,7 @@ class InferenceRouter(Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
@@ -230,6 +230,8 @@ class InferenceRouter(Inference):
             "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -320,11 +322,13 @@ class InferenceRouter(Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         logcat.debug(
             "core",
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",


@@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@@ -244,7 +246,7 @@ class MetaReferenceInferenceImpl(
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -253,6 +255,8 @@ class MetaReferenceInferenceImpl(
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"


@@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
         self,
         model_id: str,
         content: str,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,


@@ -143,7 +143,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -154,7 +154,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -163,6 +163,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         assert self.engine is not None
         request = ChatCompletionRequest(


@@ -72,7 +72,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -83,7 +83,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -92,6 +92,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,


@@ -72,11 +72,13 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -112,7 +114,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -121,6 +123,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,


@@ -71,7 +71,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -82,7 +82,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -91,6 +91,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         request = ChatCompletionRequest(
             model=model,
             messages=messages,


@@ -86,11 +86,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -157,7 +159,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -166,6 +168,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,


@@ -93,11 +93,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if content_has_media(content):
             raise NotImplementedError("Media is not supported")
@@ -188,7 +190,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -197,6 +199,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if tool_prompt_format:
             warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)


@@ -90,11 +90,13 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -145,7 +147,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -154,6 +156,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,


@@ -81,11 +81,13 @@ class PassthroughInferenceAdapter(Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         client = self._get_client()
         model = await self.model_store.get_model(model_id)
@@ -107,7 +109,7 @@ class PassthroughInferenceAdapter(Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -116,6 +118,8 @@ class PassthroughInferenceAdapter(Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         client = self._get_client()
         model = await self.model_store.get_model(model_id)


@@ -54,7 +54,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -65,7 +65,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -74,6 +74,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         request = ChatCompletionRequest(
             model=model,
             messages=messages,


@@ -74,7 +74,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -85,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -94,6 +94,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         tool_config: Optional[ToolConfig] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(


@@ -98,11 +98,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -201,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -210,6 +212,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,


@@ -70,11 +70,13 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -151,7 +153,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -160,6 +162,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,


@@ -241,11 +241,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -264,7 +266,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -273,6 +275,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:


@@ -74,7 +74,7 @@ class LiteLLMOpenAIMixin(
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -85,7 +85,7 @@ class LiteLLMOpenAIMixin(
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -94,6 +94,8 @@ class LiteLLMOpenAIMixin(
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,


@@ -136,8 +136,6 @@ ignore = [
     # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
     "C901", # Complexity of the function is too high
-    # these ignores are from flake8-bugbear; please fix!
-    "B008",
 ]

 [tool.mypy]
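
Note that with `"B008"` dropped from the `ignore` list (and assuming the flake8-bugbear rules remain selected in this configuration), Ruff will flag any future function-call defaults instead of letting them be silently reintroduced.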