From 803bf0e029098eea38ac59f5aab7c53d5bc79a3d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Fri, 7 Mar 2025 01:48:35 +0100
Subject: [PATCH] fix: solve ruff B008 warnings (#1444)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

The commit addresses the Ruff warning B008 by refactoring the code to
avoid calling SamplingParams() directly in function argument defaults.
Instead, it either uses Field(default_factory=SamplingParams) for
Pydantic models or sets the default to None and instantiates
SamplingParams inside the function body when the argument is None.

Signed-off-by: Sébastien Han
---
 llama_stack/apis/agents/agents.py                          | 2 +-
 llama_stack/apis/batch_inference/batch_inference.py        | 4 ++--
 llama_stack/apis/inference/inference.py                    | 8 ++++----
 llama_stack/distribution/routers/routers.py                | 8 ++++++--
 .../inline/inference/meta_reference/inference.py           | 8 ++++++--
 .../sentence_transformers/sentence_transformers.py         | 4 ++--
 llama_stack/providers/inline/inference/vllm/vllm.py        | 6 ++++--
 llama_stack/providers/remote/inference/bedrock/bedrock.py  | 6 ++++--
 .../providers/remote/inference/cerebras/cerebras.py        | 8 ++++++--
 .../providers/remote/inference/databricks/databricks.py    | 6 ++++--
 .../providers/remote/inference/fireworks/fireworks.py      | 8 ++++++--
 llama_stack/providers/remote/inference/nvidia/nvidia.py    | 8 ++++++--
 llama_stack/providers/remote/inference/ollama/ollama.py    | 8 ++++++--
 .../providers/remote/inference/passthrough/passthrough.py  | 8 ++++++--
 llama_stack/providers/remote/inference/runpod/runpod.py    | 6 ++++--
 .../providers/remote/inference/sambanova/sambanova.py      | 6 ++++--
 llama_stack/providers/remote/inference/tgi/tgi.py          | 8 ++++++--
 .../providers/remote/inference/together/together.py        | 8 ++++++--
 llama_stack/providers/remote/inference/vllm/vllm.py        | 8 ++++++--
 .../providers/utils/inference/litellm_openai_mixin.py      | 6 ++++--
 pyproject.toml                                             | 2 --
 21 files changed, 93 insertions(+), 43 deletions(-)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index dbe35ac09..af4b0ba77 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -199,7 +199,7 @@ AgentToolGroup = register_schema(
 
 
 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = SamplingParams()
+    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
 
     input_shields: Optional[List[str]] = Field(default_factory=list)
     output_shields: Optional[List[str]] = Field(default_factory=list)
diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py
index 0fa5c78ce..330a683ba 100644
--- a/llama_stack/apis/batch_inference/batch_inference.py
+++ b/llama_stack/apis/batch_inference/batch_inference.py
@@ -40,7 +40,7 @@ class BatchInference(Protocol):
         self,
         model: str,
         content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...
@@ -50,7 +50,7 @@ class BatchInference(Protocol):
         self,
         model: str,
         messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = list,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 08ceace4f..fa917ac22 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -278,7 +278,7 @@ ResponseFormat = register_schema(
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
-    sampling_params: Optional[SamplingParams] = SamplingParams()
+    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
@@ -357,7 +357,7 @@ class ToolConfig(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
-    sampling_params: Optional[SamplingParams] = SamplingParams()
+    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
 
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
@@ -444,7 +444,7 @@ class Inference(Protocol):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -467,7 +467,7 @@ class Inference(Protocol):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2f62a513d..3cfc2b119 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -217,7 +217,7 @@ class InferenceRouter(Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
@@ -230,6 +230,8 @@ class InferenceRouter(Inference):
             "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -320,11 +322,13 @@ class InferenceRouter(Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         logcat.debug(
             "core",
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 062bf215e..83e0b87e3 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
@@ -244,7 +246,7 @@ class MetaReferenceInferenceImpl(
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -253,6 +255,8 @@ class MetaReferenceInferenceImpl(
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index bfb09af53..b583896ad 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
         self,
         model_id: str,
         content: str,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index e28b567b2..b461bf44a 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -143,7 +143,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -154,7 +154,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -163,6 +163,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         assert self.engine is not None
 
         request = ChatCompletionRequest(
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index b82a4c752..120da5bd4 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -72,7 +72,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -83,7 +83,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -92,6 +92,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 748c5237a..a53e6e5a5 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -72,11 +72,13 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -112,7 +114,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -121,6 +123,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 9db430e4d..53a9c04f4 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -71,7 +71,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -82,7 +82,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -91,6 +91,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index e264fa434..a4cecf9f1 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -86,11 +86,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -157,7 +159,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -166,6 +168,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index db9e176ee..b59da79eb 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -93,11 +93,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if content_has_media(content):
             raise NotImplementedError("Media is not supported")
 
@@ -188,7 +190,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -197,6 +199,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if tool_prompt_format:
             warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
 
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 5a520f3b9..4d7fef8ed 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -90,11 +90,13 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -145,7 +147,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -154,6 +156,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 11da6bb9e..aa8a87bf7 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -81,11 +81,13 @@ class PassthroughInferenceAdapter(Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         client = self._get_client()
         model = await self.model_store.get_model(model_id)
 
@@ -107,7 +109,7 @@ class PassthroughInferenceAdapter(Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -116,6 +118,8 @@ class PassthroughInferenceAdapter(Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         client = self._get_client()
         model = await self.model_store.get_model(model_id)
 
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index bd620aa64..783842f71 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -54,7 +54,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -65,7 +65,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -74,6 +74,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 57a296258..a5e17c2a3 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -74,7 +74,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -85,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -94,6 +94,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         tool_config: Optional[ToolConfig] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
 
         request = ChatCompletionRequest(
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index d09ca241f..757085fb1 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -98,11 +98,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -201,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -210,6 +212,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 6fe1bd03d..0c468cdbf 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -70,11 +70,13 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -151,7 +153,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -160,6 +162,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 714d6e9e8..ac9a46e85 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -241,11 +241,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -264,7 +266,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -273,6 +275,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 92199baa9..9467996a6 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -74,7 +74,7 @@ class LiteLLMOpenAIMixin(
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -85,7 +85,7 @@ class LiteLLMOpenAIMixin(
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -94,6 +94,8 @@ class LiteLLMOpenAIMixin(
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
diff --git a/pyproject.toml b/pyproject.toml
index 08d8011b0..a58d01076 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -136,8 +136,6 @@ ignore = [
     # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
     "C901",  # Complexity of the function is too high
-    # these ignores are from flake8-bugbear; please fix!
-    "B008",
 ]
 
 [tool.mypy]
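
For context on the warning itself: Ruff B008 flags function calls used as argument defaults, because the call runs once at definition time and the resulting object is then shared by every invocation. Below is a minimal illustrative sketch of the two replacement patterns described in the commit message; it is not part of the patch, and `SamplingParams` here is a simplified stand-in for the real llama_stack type.

```python
# Illustrative sketch only; "SamplingParams" is a simplified stand-in.
from typing import List, Optional

from pydantic import BaseModel, Field


class SamplingParams(BaseModel):
    temperature: float = 1.0
    max_tokens: int = 512


# Pattern 1 (Pydantic models): Field(default_factory=...) builds a fresh
# SamplingParams for each model instance instead of sharing one object
# created at class-definition time.
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[str]
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)


# Pattern 2 (plain signatures): default to None and instantiate inside the
# body, so no object is constructed when the function is defined.
async def chat_completion(
    model_id: str,
    messages: List[str],
    sampling_params: Optional[SamplingParams] = None,
) -> ChatCompletionRequest:
    if sampling_params is None:
        sampling_params = SamplingParams()
    return ChatCompletionRequest(model=model_id, messages=messages, sampling_params=sampling_params)
```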