From f08035c2fc779f33574816e0ed2acbed05d5b78b Mon Sep 17 00:00:00 2001
From: Ihar Hrachyshka
Date: Tue, 1 Apr 2025 10:19:55 -0400
Subject: [PATCH] chore: don't use asserts to guarantee self.model_store is not None

asserts in production code are not advised; to avoid duplicating the
logic handling missing model_store, introduce helper functions.

Signed-off-by: Ihar Hrachyshka
---
 .../providers/remote/inference/ollama/ollama.py | 14 ++++++++------
 .../providers/remote/inference/vllm/vllm.py     | 13 ++++++++-----
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 4cfd81ead..5a78c07cc 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -91,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -100,10 +105,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -164,10 +168,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -286,8 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        assert self.model_store is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
 
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 6f99bf007..6a828322f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -237,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -246,10 +251,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -276,10 +280,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         #   * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -400,7 +403,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.client is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
 
         kwargs = {}
         assert model.model_type == ModelType.embedding
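
As context for the change (not part of the patch itself): below is a minimal, self-contained sketch of the pattern the commit introduces, using hypothetical stand-ins (Adapter, FakeModelStore) rather than the real llama_stack classes. It shows why the helper is preferable to a bare assert: an unset model_store surfaces as a ValueError with a clear message even under `python -O`, where asserts are stripped, and the None-handling lives in a single place instead of being repeated in every method.

# Hypothetical sketch only; class names and fields here are illustrative,
# not taken from the patched files.
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Model:
    provider_resource_id: str


class FakeModelStore:
    async def get_model(self, model_id: str) -> Model:
        # Stand-in for the real model store lookup.
        return Model(provider_resource_id=f"resolved-{model_id}")


class Adapter:
    def __init__(self) -> None:
        # model_store is injected later; it may still be unset when a
        # request arrives, which is the case the helper guards against.
        self.model_store: Optional[FakeModelStore] = None

    async def _get_model(self, model_id: str) -> Model:
        # Same shape as the helper added in the patch: raise instead of assert.
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)


async def main() -> None:
    adapter = Adapter()
    try:
        await adapter._get_model("llama3")
    except ValueError as exc:
        print(f"unset store: {exc}")  # unset store: Model store not set

    adapter.model_store = FakeModelStore()
    model = await adapter._get_model("llama3")
    print(model.provider_resource_id)  # resolved-llama3


asyncio.run(main())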