From deee355952594d230b8ed060a69eaf5d8d45a194 Mon Sep 17 00:00:00 2001
From: Ilya Kolchinsky <58424190+ilya-kolchinsky@users.noreply.github.com>
Date: Wed, 23 Apr 2025 15:33:19 +0200
Subject: [PATCH] fix: Added lazy initialization of the remote vLLM client to
 avoid issues with expired asyncio event loop (#1969)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?
Closes #1968.

The asynchronous client in `VLLMInferenceAdapter` is now initialized
directly before first use and not in `VLLMInferenceAdapter.initialize`.
This prevents issues arising due to accessing an expired event loop from
a completed `asyncio.run`.

## Test Plan
Ran unit tests, including `test_remote_vllm.py`.
Ran the code snippet mentioned in #1968.

---------

Co-authored-by: Sébastien Han
---
 .../providers/remote/inference/vllm/vllm.py | 33 ++++++++++++++-----
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d141afa86..8cfef2ee0 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None
 
     async def initialize(self) -> None:
-        log.info(f"Initializing VLLM client with base_url={self.config.url}")
-        self.client = AsyncOpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token,
-            http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
-        )
+        pass
 
     async def shutdown(self) -> None:
         pass
@@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
+    def _lazy_initialize_client(self):
+        if self.client is not None:
+            return
+
+        log.info(f"Initializing vLLM client with base_url={self.config.url}")
+        self.client = self._create_client()
+
+    def _create_client(self):
+        return AsyncOpenAI(
+            base_url=self.config.url,
+            api_key=self.config.api_token,
+            http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
+        )
+
     async def completion(
         self,
         model_id: str,
@@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
+        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -357,9 +368,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        assert self.client is not None
+        # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
+        # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
+        # Changing this may lead to unpredictable behavior.
+        client = self._create_client() if self.client is None else self.client
         model = await self.register_helper.register_model(model)
-        res = await self.client.models.list()
+        res = await client.models.list()
         available_models = [m.id async for m in res]
         if model.provider_resource_id not in available_models:
             raise ValueError(
@@ -410,6 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
+        self._lazy_initialize_client()
         assert self.client is not None
         model = await self._get_model(model_id)
 
@@ -449,6 +464,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         guided_choice: Optional[List[str]] = None,
         prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
+        self._lazy_initialize_client()
         model_obj = await self._get_model(model)
 
         extra_body: Dict[str, Any] = {}
@@ -505,6 +521,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        self._lazy_initialize_client()
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
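
For context on the pattern the diff applies, below is a minimal, self-contained sketch of lazy client initialization. It is illustrative only: the class name, endpoint, and use of a bare `httpx.AsyncClient` are stand-ins, not code from the Llama Stack repository. It demonstrates the same idea as the patch: the async client is constructed on first use, inside whatever event loop actually drives the requests, rather than during an earlier initialization pass that may have run under its own, now-completed `asyncio.run()`.

```python
from typing import List, Optional

import httpx


class LazyAsyncClientSketch:
    """Illustrative sketch only; not Llama Stack code."""

    def __init__(self, base_url: str) -> None:
        self.base_url = base_url
        self.client: Optional[httpx.AsyncClient] = None

    def _lazy_initialize_client(self) -> None:
        # No-op after the first call: the client is built exactly once,
        # from within a running event loop, at the time of the first request.
        if self.client is None:
            self.client = self._create_client()

    def _create_client(self) -> httpx.AsyncClient:
        return httpx.AsyncClient(base_url=self.base_url)

    async def list_models(self) -> List[str]:
        self._lazy_initialize_client()
        assert self.client is not None
        # Assumes an OpenAI-compatible /v1/models endpoint, e.g. a vLLM server.
        resp = await self.client.get("/v1/models")
        resp.raise_for_status()
        return [m["id"] for m in resp.json().get("data", [])]
```

Called as `asyncio.run(LazyAsyncClientSketch("http://localhost:8000").list_models())`, the client comes into existence inside that same `asyncio.run()`. The patch gives `VLLMInferenceAdapter` the same property, while `register_model` builds a temporary client of its own because it runs during stack initialization, before `self.client` may safely exist.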