chore: don't use asserts to guarantee self.model_store is not None

Asserts are not advisable in production code: they are stripped when
Python runs with -O and otherwise surface as bare AssertionErrors. To
avoid duplicating the logic that handles a missing model_store,
introduce a _get_model() helper in each adapter.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-04-01 10:19:55 -04:00
parent d443607d65
commit f08035c2fc
2 changed files with 16 additions and 11 deletions
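
For context, the pattern applied in both adapters below can be sketched in isolation. This is a minimal, hypothetical example (ExampleAdapter and _FakeModelStore are stand-ins, not the real llama-stack classes): the per-method assert on self.model_store is replaced by one _get_model() helper that raises a descriptive ValueError.

# Minimal sketch of the helper pattern (hypothetical names; simplified
# from the adapters changed below).
from typing import Any, Dict, Optional


class _FakeModelStore:
    # Hypothetical stand-in for the routing-table model store.
    async def get_model(self, model_id: str) -> Dict[str, Any]:
        return {"identifier": model_id, "provider_resource_id": model_id}


class ExampleAdapter:
    def __init__(self, model_store: Optional[_FakeModelStore] = None) -> None:
        self.model_store = model_store

    async def _get_model(self, model_id: str) -> Dict[str, Any]:
        # Explicit check instead of `assert`: it is not stripped by
        # `python -O` and raises an error type callers can reasonably catch.
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def completion(self, model_id: str) -> str:
        # Every entry point now funnels through the single helper.
        model = await self._get_model(model_id)
        return f"would call the backend with {model['provider_resource_id']}"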

@@ -91,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -100,10 +105,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -164,10 +168,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -286,8 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        assert self.model_store is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"

@@ -237,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -246,10 +251,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -276,10 +280,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -400,7 +403,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.client is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         kwargs = {}
         assert model.model_type == ModelType.embedding
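
A short usage sketch of the hypothetical ExampleAdapter above shows what changes for callers: with no model_store wired in, they see a descriptive ValueError instead of a bare AssertionError, and the check still runs under python -O, which strips assert statements. The model id is arbitrary.

import asyncio


# Continuing the hypothetical ExampleAdapter sketch from above.
async def main() -> None:
    adapter = ExampleAdapter(model_store=None)
    try:
        await adapter.completion("llama3.2:3b")
    except ValueError as exc:
        print(f"caught: {exc}")  # caught: Model store not set

    wired = ExampleAdapter(model_store=_FakeModelStore())
    print(await wired.completion("llama3.2:3b"))


asyncio.run(main())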