chore: don't use asserts to guarantee self.model_store is not None

Asserts in production code are not advised; to avoid duplicating the logic that handles a missing model_store, introduce a _get_model() helper in each adapter.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Parent: d443607d65
Commit: f08035c2fc

2 changed files with 16 additions and 11 deletions
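The change replaces the per-method `assert self.model_store is not None` guards with a single helper that raises a ValueError. Below is a minimal standalone sketch of that pattern, assuming simplified stand-ins for the llama-stack types (the Model and ModelStore definitions here are illustrative, not the real ones):

    from typing import Optional, Protocol


    class Model:
        """Simplified stand-in for the llama-stack Model type."""

        def __init__(self, provider_resource_id: str) -> None:
            self.provider_resource_id = provider_resource_id


    class ModelStore(Protocol):
        """Simplified stand-in for the model store injected into adapters."""

        async def get_model(self, model_id: str) -> Model: ...


    class InferenceAdapter:
        # The store is injected after construction, so it can legitimately be
        # None until the adapter is fully wired up.
        model_store: Optional[ModelStore] = None

        async def _get_model(self, model_id: str) -> Model:
            # Raise a real exception instead of using `assert`, which is
            # stripped when Python runs with -O.
            if not self.model_store:
                raise ValueError("Model store not set")
            return await self.model_store.get_model(model_id)

        async def completion(self, model_id: str) -> str:
            # Every API method resolves the model through the helper, so the
            # missing-store handling lives in exactly one place.
            model = await self._get_model(model_id)
            return f"completion for {model.provider_resource_id}"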
@@ -91,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -100,10 +105,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -164,10 +168,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -286,8 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        assert self.model_store is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
 
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
@@ -237,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -246,10 +251,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -276,10 +280,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         #   * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -400,7 +403,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.client is not None
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
 
         kwargs = {}
         assert model.model_type == ModelType.embedding
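As an aside on the motivation (not part of the commit itself): `assert` statements are removed entirely when Python runs with -O, so an assert-based guard offers no protection in optimized runs, while an explicit check always raises. A small illustration of the difference, using a hypothetical store argument:

    from typing import Optional


    def guarded_by_assert(store: Optional[dict]) -> dict:
        # Under `python -O` this assert is compiled away and a None store
        # propagates onward as a confusing downstream error.
        assert store is not None
        return store


    def guarded_by_check(store: Optional[dict]) -> dict:
        # The explicit check raises a clear error regardless of interpreter
        # flags, mirroring the ValueError used by the new _get_model() helpers.
        if store is None:
            raise ValueError("Model store not set")
        return store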