Repository: https://github.com/meta-llama/llama-stack.git

Commit 7e4765c45b ("address feedback")
Parent: 96b1bafcde
3 changed files with 44 additions and 8 deletions
OllamaInferenceAdapter:

@@ -73,7 +73,7 @@ model_aliases = [
 
 class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        self.model_register_helper = ModelRegistryHelper(model_aliases)
+        self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
@@ -201,7 +201,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
                     request,
-                    self.model_register_helper.get_llama_model(request.model),
+                    self.register_helper.get_llama_model(request.model),
                     self.formatter,
                 )
         else:
@@ -282,7 +282,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         raise NotImplementedError()
 
     async def register_model(self, model: Model) -> Model:
-        model = await self.model_register_helper.register_model(model)
+        model = await self.register_helper.register_model(model)
        models = await self.client.ps()
        available_models = [m["model"] for m in models["models"]]
        if model.provider_resource_id not in available_models:
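The register_model hunk above checks the requested model against whatever the running Ollama server reports. A minimal standalone sketch of that check follows; the response payload, the model names, and the ValueError are illustrative assumptions, since the diff stops before the error-handling branch:

# Sketch of the availability check performed in register_model above.
# The payload shape {"models": [{"model": ...}]} mirrors the list
# comprehension in the diff; the values are made up for this example.
ps_response = {"models": [{"model": "llama3.1:8b-instruct-fp16"}]}  # pretend client.ps() result
available_models = [m["model"] for m in ps_response["models"]]

requested = "llama3.1:70b-instruct-fp16"  # hypothetical provider_resource_id
if requested not in available_models:
    # Assumed failure behavior; the commit does not show what follows the check.
    raise ValueError(f"Model {requested} is not currently loaded in Ollama")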
VLLMInferenceAdapter:

@@ -47,7 +47,7 @@ def build_model_aliases():
 
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        self.model_register_helper = ModelRegistryHelper(build_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
@@ -129,12 +129,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        model = await self.model_register_helper.register_model(model)
+        model = await self.register_helper.register_model(model)
         res = self.client.models.list()
         available_models = [m.id for m in res]
         if model.provider_resource_id not in available_models:
             raise ValueError(
-                f"Model {model.provider_resource_id} is not being served by vLLM"
+                f"Model {model.provider_resource_id} is not being served by vLLM. "
+                f"Available models: {', '.join(available_models)}"
             )
         return model
 
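For context, here is roughly what the reworded error reads like once the two f-strings are joined; the model names below are made up for the example:

# Illustration of the improved ValueError message; values are hypothetical.
available_models = ["meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.2-3B-Instruct"]
provider_resource_id = "meta-llama/Llama-3.1-70B-Instruct"  # not served in this example
message = (
    f"Model {provider_resource_id} is not being served by vLLM. "
    f"Available models: {', '.join(available_models)}"
)
# -> "Model meta-llama/Llama-3.1-70B-Instruct is not being served by vLLM.
#     Available models: meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-3B-Instruct"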
@@ -157,7 +158,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
                     request,
-                    self.model_register_helper.get_llama_model(request.model),
+                    self.register_helper.get_llama_model(request.model),
                     self.formatter,
                 )
         else:
@@ -166,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.model_register_helper.get_llama_model(request.model),
+                self.register_helper.get_llama_model(request.model),
                 self.formatter,
             )
 
New test file: llama_stack/providers/tests/inference/test_model_registration.py

@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+# How to run this test:
+#
+# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
+#   -m "meta_reference"
+#   --env TOGETHER_API_KEY=<your_api_key>
+
+
+class TestModelRegistration:
+    @pytest.mark.asyncio
+    async def test_register_unsupported_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a model that's too large for local inference
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3.1-70B-Instruct",
+            )
+
+    @pytest.mark.asyncio
+    async def test_register_nonexistent_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a non-existent model
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3-NonExistent-Model",
+            )
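Both tests capture the exception via exc_info without inspecting it. If one wanted to also pin down the failure reason, the pattern would look like the sketch below; the asserted message fragment is an assumption, since the actual wording is provider-specific and not part of this commit:

# Hedged sketch only: assumes the error text mentions the offending model_id.
with pytest.raises(Exception) as exc_info:
    await models_impl.register_model(
        model_id="Llama3-NonExistent-Model",
    )
assert "Llama3-NonExistent-Model" in str(exc_info.value)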