diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 185aeeb03..bd7f5073c 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -45,27 +45,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
 
     async def register_model(self, model: Model) -> None:
-        raise ValueError("Model registration is not supported for vLLM models")
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def list_models(self) -> List[Model]:
-        models = []
-        for model in self.client.models.list():
-            repo = model.id
+        for running_model in self.client.models.list():
+            repo = running_model.id
             if repo not in self.huggingface_repo_to_llama_model_id:
                 print(f"Unknown model served by vllm: {repo}")
                 continue
 
             identifier = self.huggingface_repo_to_llama_model_id[repo]
-            models.append(
-                Model(
-                    identifier=identifier,
-                    llama_model=identifier,
+            if identifier == model.provider_resource_id:
+                print(
+                    f"Verified that model {model.provider_resource_id} is being served by vLLM"
                 )
-            )
-        return models
+                return
+
+        raise ValueError(
+            f"Model {model.provider_resource_id} is not being served by vLLM"
+        )
+
+    async def shutdown(self) -> None:
+        pass
 
     async def completion(
         self,
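
For context, the verification that replaces the old "registration not supported" error boils down to the check sketched below. This is an illustrative, standalone sketch only: it reuses the `OpenAI` client call and the repo-to-identifier lookup shown in the diff, but the server URL, API key, and the example mapping are placeholder assumptions, not llama-stack's actual configuration or API.

```python
# Standalone sketch of the register_model verification logic from the patch.
# Assumptions (not from the patch): the vLLM endpoint, API key, and the
# example mapping below are placeholders for illustration only.
from openai import OpenAI

# Maps HuggingFace repo names (as reported by vLLM) to Llama Stack identifiers.
HUGGINGFACE_REPO_TO_LLAMA_MODEL_ID = {
    "meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct",  # placeholder entry
}


def verify_model_is_served(client: OpenAI, provider_resource_id: str) -> None:
    """Raise ValueError unless vLLM reports a model that maps to provider_resource_id."""
    for running_model in client.models.list():
        repo = running_model.id
        if repo not in HUGGINGFACE_REPO_TO_LLAMA_MODEL_ID:
            print(f"Unknown model served by vllm: {repo}")
            continue
        identifier = HUGGINGFACE_REPO_TO_LLAMA_MODEL_ID[repo]
        if identifier == provider_resource_id:
            print(f"Verified that model {provider_resource_id} is being served by vLLM")
            return
    raise ValueError(f"Model {provider_resource_id} is not being served by vLLM")


if __name__ == "__main__":
    # Placeholder endpoint; vLLM's OpenAI-compatible server typically ignores the key.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake")
    verify_model_is_served(client, "Llama3.1-8B-Instruct")
```

With this change, `register_model` succeeds only when the requested `provider_resource_id` maps to a model the vLLM server is actually serving, and raises a `ValueError` otherwise instead of rejecting registration outright.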