diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 98baf846e..1d4012c19 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -394,7 +394,7 @@ class EmbeddingsResponse(BaseModel):
 
 
 class ModelStore(Protocol):
-    def get_model(self, identifier: str) -> Model: ...
+    async def get_model(self, identifier: str) -> Model: ...
 
 
 class TextTruncation(Enum):
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index a2ac113e8..b28cb2016 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -103,7 +103,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -167,7 +167,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -287,7 +287,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.model_store is not None
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
 
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 26e429592..8aed67d04 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -249,7 +249,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -279,7 +279,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -397,7 +397,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.client is not None
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         kwargs = {}
         assert model.model_type == ModelType.embedding
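
The diff changes ModelStore.get_model from a synchronous Protocol method to a coroutine, so every adapter call site must now await it. Below is a minimal, self-contained sketch of that pattern, not llama_stack's actual registry: Model, InMemoryModelStore, and resolve are illustrative stand-ins used only to show how an async Protocol method constrains its callers.

import asyncio
from dataclasses import dataclass
from typing import Protocol


@dataclass
class Model:
    # Stand-in for the real Model resource; only the fields needed here.
    identifier: str
    provider_resource_id: str


class ModelStore(Protocol):
    # Declaring the method `async def` means callers must `await` it,
    # mirroring the Protocol change in inference.py above.
    async def get_model(self, identifier: str) -> Model: ...


class InMemoryModelStore:
    """Toy store used only to exercise the async protocol (hypothetical)."""

    def __init__(self) -> None:
        self._models = {
            "my-model": Model(identifier="my-model", provider_resource_id="provider/my-model"),
        }

    async def get_model(self, identifier: str) -> Model:
        # A real store might consult a registry or database here, hence async.
        return self._models[identifier]


async def resolve(store: ModelStore, model_id: str) -> str:
    # Mirrors the adapter call sites: the lookup result must be awaited
    # now that get_model is a coroutine.
    model = await store.get_model(model_id)
    return model.provider_resource_id


if __name__ == "__main__":
    print(asyncio.run(resolve(InMemoryModelStore(), "my-model")))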