Mirror of https://github.com/meta-llama/llama-stack.git
Model registration in ollama and vllm now checks against the models available in the provider (#446)

Tests:

    pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_text_inference.py
    pytest -v -s -m vllm_remote llama_stack/providers/tests/inference/test_text_inference.py --env VLLM_URL="http://localhost:9798/v1"
Parent: 7f6ac2fbd7
Commit: 787e2034b7
4 changed files with 73 additions and 14 deletions
Ollama adapter (OllamaInferenceAdapter):

@@ -71,12 +71,9 @@ model_aliases = [
 ]


-class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=model_aliases,
-        )
+        self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
         self.formatter = ChatFormat(Tokenizer.get_instance())

@@ -203,7 +200,9 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (

@@ -282,6 +281,18 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

+    async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
+        models = await self.client.ps()
+        available_models = [m["model"] for m in models["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                f"Available models: {', '.join(available_models)}"
+            )
+
+        return model
+

 async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
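The Ollama check relies on client.ps() and the dict-shaped response used above. As a minimal standalone sketch of the same availability check, assuming the ollama Python package's AsyncClient and that response shape (the helper name ensure_model_available is hypothetical, not part of this commit):

    from ollama import AsyncClient


    async def ensure_model_available(client: AsyncClient, provider_resource_id: str) -> None:
        # Mirrors the adapter's new register_model check: ps() lists running models,
        # each keyed by its "model" name.
        running = await client.ps()
        available = [m["model"] for m in running["models"]]
        if provider_resource_id not in available:
            raise ValueError(
                f"Model '{provider_resource_id}' is not available in Ollama. "
                f"Available models: {', '.join(available)}"
            )

Note that ollama's ps() reports models currently loaded into memory, while list() reports everything pulled locally; the commit uses ps(), so the sketch follows it.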
vLLM adapter (VLLMInferenceAdapter):

@@ -45,12 +45,9 @@ def build_model_aliases():
     ]


-class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=build_model_aliases(),
-        )
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None

@@ -131,6 +128,17 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
         ):
             yield chunk

+    async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
+        res = self.client.models.list()
+        available_models = [m.id for m in res]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model {model.provider_resource_id} is not being served by vLLM. "
+                f"Available models: {', '.join(available_models)}"
+            )
+        return model
+
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:

@@ -149,7 +157,9 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (

@@ -157,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.get_llama_model(request.model),
+                self.register_helper.get_llama_model(request.model),
                 self.formatter,
             )
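The vLLM path queries the server's OpenAI-compatible /v1/models endpoint through self.client.models.list(). A minimal sketch of the same check as a standalone function, assuming the openai Python client and a vLLM server reachable at the VLLM_URL from the test command above; the helper name and the placeholder API key are assumptions, not part of the commit:

    from openai import OpenAI


    def ensure_model_served(base_url: str, provider_resource_id: str) -> None:
        # vLLM exposes an OpenAI-compatible API, so the served model ids come back
        # from the standard models.list() call.
        client = OpenAI(base_url=base_url, api_key="unused")  # placeholder key; assumes the server was not started with --api-key
        available = [m.id for m in client.models.list()]
        if provider_resource_id not in available:
            raise ValueError(
                f"Model {provider_resource_id} is not being served by vLLM. "
                f"Available models: {', '.join(available)}"
            )

Against the server from the test command, ensure_model_served("http://localhost:9798/v1", "<served-model-id>") raises unless that exact id appears in the list.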
New test file, llama_stack/providers/tests/inference/test_model_registration.py:

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
#   -m "meta_reference"
#   --env TOGETHER_API_KEY=<your_api_key>


class TestModelRegistration:
    @pytest.mark.asyncio
    async def test_register_unsupported_model(self, inference_stack):
        _, models_impl = inference_stack

        # Try to register a model that's too large for local inference
        with pytest.raises(Exception) as exc_info:
            await models_impl.register_model(
                model_id="Llama3.1-70B-Instruct",
            )

    @pytest.mark.asyncio
    async def test_register_nonexistent_model(self, inference_stack):
        _, models_impl = inference_stack

        # Try to register a non-existent model
        with pytest.raises(Exception) as exc_info:
            await models_impl.register_model(
                model_id="Llama3-NonExistent-Model",
            )
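Both tests cover only failure paths. A hypothetical companion method for the same TestModelRegistration class, not part of this commit, sketching the happy path on the inference_stack fixture; the model id is illustrative and assumes the configured provider actually serves it:

    @pytest.mark.asyncio
    async def test_register_supported_model(self, inference_stack):
        _, models_impl = inference_stack

        # Assumes e.g. an ollama or vllm_remote stack serving a model registered
        # under this llama-stack identifier.
        model = await models_impl.register_model(model_id="Llama3.1-8B-Instruct")
        assert model.provider_resource_id is not None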
ModelRegistryHelper (shared model registry helper):

@@ -54,7 +54,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         raise ValueError(f"Unknown model: `{identifier}`")

     def get_llama_model(self, provider_model_id: str) -> str:
-        return self.provider_id_to_llama_model_map[provider_model_id]
+        if provider_model_id in self.provider_id_to_llama_model_map:
+            return self.provider_id_to_llama_model_map[provider_model_id]
+        else:
+            return None

     async def register_model(self, model: Model) -> Model:
         model.provider_resource_id = self.get_provider_model_id(
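The behavioral change in get_llama_model is that a provider model id with no registered Llama alias now yields None instead of raising KeyError, which is what lets the adapters above register models the alias table does not know about. A sketch of an equivalent formulation using dict.get, assuming the same provider_id_to_llama_model_map attribute (the return annotation would then more accurately be Optional[str]):

    from typing import Optional


    def get_llama_model(self, provider_model_id: str) -> Optional[str]:
        # dict.get returns None for unknown ids, matching the if/else in the diff
        return self.provider_id_to_llama_model_map.get(provider_model_id)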