chore: don't use asserts to guarantee self.model_store is not None

asserts are not advised in production code: they are stripped when
Python runs with -O, so they cannot guard against a missing
model_store at runtime. To avoid duplicating the None check, introduce
a _get_model helper in each adapter that raises ValueError instead.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-04-01 10:19:55 -04:00
parent d443607d65
commit f08035c2fc
2 changed files with 16 additions and 11 deletions
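For context, here is the pattern being replaced, reduced to a minimal self-contained sketch. The `ModelStore` and `Adapter` classes below are illustrative stand-ins, not the real llama-stack types; only the `_get_model` body and its error message mirror the diff. Under `python -O` an `assert` is compiled away, so the old code would fail later with an AttributeError on None, while the helper raises a clear ValueError on every code path and keeps the check in one place.

```python
import asyncio
from typing import Optional


class ModelStore:
    """Illustrative stand-in for the model store attached to adapters."""

    async def get_model(self, model_id: str) -> str:
        return f"resolved:{model_id}"


class Adapter:
    """Illustrative stand-in for OllamaInferenceAdapter / VLLMInferenceAdapter."""

    def __init__(self, model_store: Optional[ModelStore] = None) -> None:
        self.model_store = model_store

    # Old style: the assert is removed entirely under `python -O`, so a
    # missing model_store surfaces later as AttributeError on None.
    async def completion_old(self, model_id: str) -> str:
        assert self.model_store is not None
        return await self.model_store.get_model(model_id)

    # New style (mirrors the diff): one helper owns the None check and
    # raises a clear error regardless of interpreter flags.
    async def _get_model(self, model_id: str) -> str:
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def completion_new(self, model_id: str) -> str:
        return await self._get_model(model_id)


if __name__ == "__main__":
    adapter = Adapter()  # model_store never attached
    try:
        asyncio.run(adapter.completion_new("llama3"))
    except ValueError as exc:
        print(f"caught: {exc}")  # caught: Model store not set
```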

@@ -91,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -100,10 +105,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -164,10 +168,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -286,8 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        assert self.model_store is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
         )

@@ -237,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -246,10 +251,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -276,10 +280,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -400,7 +403,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.client is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         kwargs = {}
         assert model.model_type == ModelType.embedding
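If one wanted to pin the new behavior down in a test, a sketch along these lines would do it. The real adapters' constructors and import paths are not shown in this diff, so the example uses a minimal stub with the same `_get_model` shape rather than the actual OllamaInferenceAdapter/VLLMInferenceAdapter classes, and it assumes pytest with the pytest-asyncio plugin for the async test.

```python
import pytest


class _StubAdapter:
    """Stub with the same _get_model shape as the adapters in this commit."""

    def __init__(self) -> None:
        self.model_store = None  # deliberately never attached

    async def _get_model(self, model_id: str):
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)


@pytest.mark.asyncio  # requires the pytest-asyncio plugin
async def test_get_model_without_store_raises():
    adapter = _StubAdapter()
    with pytest.raises(ValueError, match="Model store not set"):
        await adapter._get_model("some-model-id")
```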