mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
fix vllm registry
This commit is contained in:
parent
a019011326
commit
e272f8aa62
3 changed files with 10 additions and 4 deletions
|
@ -283,10 +283,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
# First perform the parent class's registration check
|
|
||||||
model = await super().register_model(model)
|
model = await super().register_model(model)
|
||||||
|
|
||||||
# Additional Ollama-specific check
|
|
||||||
models = await self.client.ps()
|
models = await self.client.ps()
|
||||||
available_models = [m["model"] for m in models["models"]]
|
available_models = [m["model"] for m in models["models"]]
|
||||||
if model.provider_resource_id not in available_models:
|
if model.provider_resource_id not in available_models:
|
||||||
|
|
|
@ -131,6 +131,15 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> None:
|
||||||
|
model = await super().register_model(model)
|
||||||
|
res = self.client.models.list()
|
||||||
|
available_models = [m.id for m in res]
|
||||||
|
if model.provider_resource_id not in available_models:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model {model.provider_resource_id} is not being served by vLLM"
|
||||||
|
)
|
||||||
|
|
||||||
async def _get_params(
|
async def _get_params(
|
||||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -57,7 +57,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
if provider_model_id in self.provider_id_to_llama_model_map:
|
if provider_model_id in self.provider_id_to_llama_model_map:
|
||||||
return self.provider_id_to_llama_model_map[provider_model_id]
|
return self.provider_id_to_llama_model_map[provider_model_id]
|
||||||
else:
|
else:
|
||||||
None
|
return None
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
model.provider_resource_id = self.get_provider_model_id(
|
model.provider_resource_id = self.get_provider_model_id(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue