fix vllm registry

Dinesh Yeduguru 2024-11-13 12:12:22 -08:00
parent a019011326
commit e272f8aa62
3 changed files with 10 additions and 4 deletions


@@ -283,10 +283,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
        raise NotImplementedError()

    async def register_model(self, model: Model) -> Model:
        # First perform the parent class's registration check
        model = await super().register_model(model)

        # Additional Ollama-specific check
        models = await self.client.ps()
        available_models = [m["model"] for m in models["models"]]
        if model.provider_resource_id not in available_models:
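
For reference, the Ollama check above asks the running server which models are loaded via the client's ps() call. Below is a minimal standalone sketch of the same pattern, assuming the ollama Python client and the dict-style response shape used by the adapter; the helper name and error message are illustrative, not part of the commit.

from ollama import AsyncClient


async def ensure_model_is_running(client: AsyncClient, provider_resource_id: str) -> None:
    # /api/ps reports only the models currently loaded by the Ollama server
    running = await client.ps()
    available_models = [m["model"] for m in running["models"]]
    if provider_resource_id not in available_models:
        raise ValueError(f"Model {provider_resource_id} is not currently running in Ollama")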


@@ -131,6 +131,15 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
        ):
            yield chunk

+    async def register_model(self, model: Model) -> None:
+        model = await super().register_model(model)
+        res = self.client.models.list()
+        available_models = [m.id for m in res]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model {model.provider_resource_id} is not being served by vLLM"
+            )
+
    async def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
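
The new vLLM check goes through vLLM's OpenAI-compatible /v1/models endpoint, listed with the OpenAI client the adapter keeps in self.client. Below is a minimal standalone sketch of the same check, assuming the openai Python SDK; the function name, base URL, placeholder API key, and example model id are illustrative.

from openai import OpenAI


def ensure_model_is_served(base_url: str, model_id: str) -> None:
    # vLLM exposes /v1/models, so the stock OpenAI client can list what is being served
    client = OpenAI(base_url=base_url, api_key="not-needed")
    available_models = [m.id for m in client.models.list()]
    if model_id not in available_models:
        raise ValueError(f"Model {model_id} is not being served by vLLM")


# Example usage against a local vLLM server (placeholder URL and model id):
# ensure_model_is_served("http://localhost:8000/v1", "meta-llama/Llama-3.1-8B-Instruct")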


@@ -57,7 +57,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
        if provider_model_id in self.provider_id_to_llama_model_map:
            return self.provider_id_to_llama_model_map[provider_model_id]
        else:
-            None
+            return None

    async def register_model(self, model: Model) -> Model:
        model.provider_resource_id = self.get_provider_model_id(
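
Note that the model_registry change is behavior-preserving: the bare None was a no-op expression, so the method already fell through and returned None implicitly; the explicit return just states the intent. A minimal sketch of the lookup behavior follows, with a hypothetical class name and alias map.

from typing import Dict, Optional


class ModelLookup:
    # Hypothetical stand-in for ModelRegistryHelper's provider-id lookup
    def __init__(self, provider_id_to_llama_model_map: Dict[str, str]) -> None:
        self.provider_id_to_llama_model_map = provider_id_to_llama_model_map

    def get_provider_model_id(self, provider_model_id: str) -> Optional[str]:
        if provider_model_id in self.provider_id_to_llama_model_map:
            return self.provider_id_to_llama_model_map[provider_model_id]
        else:
            return None  # explicit, rather than a bare `None` expression


lookup = ModelLookup({"llama3.1:8b-instruct-fp16": "Llama3.1-8B-Instruct"})
assert lookup.get_provider_model_id("llama3.1:8b-instruct-fp16") == "Llama3.1-8B-Instruct"
assert lookup.get_provider_model_id("unknown-id") is None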