Model registration in Ollama and vLLM now checks against the available models in the provider (#446)

Tests:

pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_text_inference.py

pytest -v -s -m vllm_remote llama_stack/providers/tests/inference/test_text_inference.py --env VLLM_URL="http://localhost:9798/v1"

---------
Dinesh Yeduguru 2024-11-13 13:04:06 -08:00 committed by GitHub
parent 7f6ac2fbd7
commit 787e2034b7
4 changed files with 73 additions and 14 deletions
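
For reference, both adapters now perform the same registration-time check. A minimal sketch of the shared logic follows; the helper function here is illustrative, not something added by this commit:

def check_model_is_served(provider_resource_id: str, available_models: list[str]) -> None:
    # Fail registration early if the provider is not actually serving the requested model.
    if provider_resource_id not in available_models:
        raise ValueError(
            f"Model '{provider_resource_id}' is not available. "
            f"Available models: {', '.join(available_models)}"
        )

The Ollama adapter builds the available list from client.ps(), while the vLLM adapter reads it from the OpenAI-compatible models listing, as shown in the hunks below.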


@@ -71,12 +71,9 @@ model_aliases = [
 ]
 
-class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=model_aliases,
-        )
+        self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -203,7 +200,9 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -282,6 +281,18 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 
+    async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
+        models = await self.client.ps()
+        available_models = [m["model"] for m in models["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                f"Available models: {', '.join(available_models)}"
+            )
+        return model
+
 
 async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
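
The same check can be reproduced outside the stack with the ollama Python client. A rough sketch, assuming the dict-style response shape used in the adapter code above; the host is Ollama's default and the model tag is a placeholder. Note that ps() reports models currently loaded into memory (the equivalent of "ollama ps"), so a model that has been pulled but is not running will fail this check.

import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")  # default Ollama URL
    running = await client.ps()  # models currently loaded into memory
    available = [m["model"] for m in running["models"]]
    target = "llama3.1:8b-instruct-fp16"  # placeholder tag
    if target not in available:
        raise ValueError(f"'{target}' is not available in Ollama. Running models: {available}")
    print(f"'{target}' is loaded and ready")


asyncio.run(main())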


@@ -45,12 +45,9 @@ def build_model_aliases():
     ]
 
-class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=build_model_aliases(),
-        )
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
@@ -131,6 +128,17 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
         ):
             yield chunk
 
+    async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
+        res = self.client.models.list()
+        available_models = [m.id for m in res]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model {model.provider_resource_id} is not being served by vLLM. "
+                f"Available models: {', '.join(available_models)}"
+            )
+        return model
+
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
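
On the vLLM side, self.client.models.list() suggests the adapter is talking to vLLM through the OpenAI-compatible client, so the equivalent standalone check is roughly the sketch below. The base_url matches the VLLM_URL in the test command above; the model id is a placeholder, and the api_key is only enforced if the server was started with --api-key.

from openai import OpenAI  # vLLM serves an OpenAI-compatible API under /v1

client = OpenAI(base_url="http://localhost:9798/v1", api_key="not-needed")
available = [m.id for m in client.models.list()]

target = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder id
if target not in available:
    raise ValueError(
        f"Model {target} is not being served by vLLM. "
        f"Available models: {', '.join(available)}"
    )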
@@ -149,7 +157,9 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -157,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.get_llama_model(request.model),
+                self.register_helper.get_llama_model(request.model),
                 self.formatter,
             )


@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+# How to run this test:
+#
+# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
+#   -m "meta_reference"
+#   --env TOGETHER_API_KEY=<your_api_key>
+
+
+class TestModelRegistration:
+    @pytest.mark.asyncio
+    async def test_register_unsupported_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a model that's too large for local inference
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3.1-70B-Instruct",
+            )
+
+    @pytest.mark.asyncio
+    async def test_register_nonexistent_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a non-existent model
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3-NonExistent-Model",
+            )
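
Presumably this new file can also be run against the providers exercised in this commit by reusing the marks from the test commands above, e.g.:

pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_model_registration.py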


@@ -54,7 +54,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             raise ValueError(f"Unknown model: `{identifier}`")
 
     def get_llama_model(self, provider_model_id: str) -> str:
-        return self.provider_id_to_llama_model_map[provider_model_id]
+        if provider_model_id in self.provider_id_to_llama_model_map:
+            return self.provider_id_to_llama_model_map[provider_model_id]
+        else:
+            return None
 
     async def register_model(self, model: Model) -> Model:
         model.provider_resource_id = self.get_provider_model_id(
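
The helper change in isolation: an unmapped provider model id now yields None instead of raising KeyError, presumably so that models outside the static alias table can still be registered without the helper raising. A small sketch (the import path and the empty alias list are assumptions for illustration):

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

helper = ModelRegistryHelper(model_aliases=[])  # no aliases registered

# Before this commit: KeyError. After: None, leaving the decision to the caller.
assert helper.get_llama_model("some-unregistered-provider-id") is None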