forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR adds SambaNova as one of the Provider - Add SambaNova as a provider ## Test Plan Test the functional command ``` pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_embeddings.py llama_stack/providers/tests/inference/test_prompt_adapter.py llama_stack/providers/tests/inference/test_text_inference.py llama_stack/providers/tests/inference/test_vision_inference.py --env SAMBANOVA_API_KEY=<sambanova-api-key> ``` Test the distribution template: ``` # Docker LLAMA_STACK_PORT=5001 docker run -it -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ llamastack/distribution-sambanova \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY # Conda llama stack build --template sambanova --image-type conda llama stack run ./run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` ## Source [SambaNova API Documentation](https://cloud.sambanova.ai/apis) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [Y] Ran pre-commit to handle lint / formatting issues. - [Y] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [Y] Updated relevant documentation. - [Y ] Wrote necessary unit or integration tests. --------- Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
95 lines
3.3 KiB
Python
95 lines
3.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
# How to run this test (a single command; the file path on the last line is its
# final argument):
#
# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
# ./llama_stack/providers/tests/inference/test_model_registration.py
|
|
|
|
|
|
class TestModelRegistration:
    """Tests for registering models through the stack's models API.

    Every test receives the ``inference_stack`` fixture and unpacks it as
    ``(inference_impl, models_impl)``. The fixtures are defined outside this
    module (presumably in a conftest -- not visible here).
    """

    @pytest.mark.asyncio
    async def test_register_unsupported_model(self, inference_stack, inference_model):
        """Registering a 70B model should raise ``ValueError`` on the listed providers.

        Only providers whose ``provider_type`` appears in the tuple below run
        the assertion; every other provider type is skipped.
        """
        inference_impl, models_impl = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "meta-reference",
            "remote::ollama",
            "remote::vllm",
            "remote::tgi",
        ):
            pytest.skip(
                "Skipping test for remote inference providers since they can handle large models like 70B instruct"
            )

        # Try to register a model that's too large for local inference
        with pytest.raises(ValueError):
            await models_impl.register_model(
                model_id="Llama3.1-70B-Instruct",
            )

    @pytest.mark.asyncio
    async def test_register_nonexistent_model(self, inference_stack):
        """Registering an unknown model id should raise an exception."""
        _, models_impl = inference_stack

        # Try to register a non-existent model.
        # NOTE(review): this catches any Exception (flagged by ruff B017); the
        # concrete type raised is not visible from this file -- narrow it
        # (and/or add ``match=``) once confirmed.
        with pytest.raises(Exception):
            await models_impl.register_model(
                model_id="Llama3-NonExistent-Model",
            )

    @pytest.mark.asyncio
    async def test_register_with_llama_model(self, inference_stack):
        """A custom model id mapped to a llama model registers with ``skip_load``.

        A second registration that reuses the first model's id as its
        ``provider_model_id`` (and omits ``skip_load``) is expected to raise
        ``ValueError``.
        """
        _, models_impl = inference_stack

        # presumably "skip_load" tells the provider not to load weights -- confirm
        # against the provider implementation.
        await models_impl.register_model(
            model_id="custom-model",
            metadata={
                "llama_model": "meta-llama/Llama-2-7b",
                "skip_load": True,
            },
        )

        # The validation that rejects this call lives outside this file; the
        # test only pins that it surfaces as ValueError.
        with pytest.raises(ValueError):
            await models_impl.register_model(
                model_id="custom-model-2",
                metadata={
                    "llama_model": "meta-llama/Llama-2-7b",
                },
                provider_model_id="custom-model",
            )

    @pytest.mark.asyncio
    async def test_initialize_model_during_registering(self, inference_stack):
        """Registering a model should call the meta-reference ``load_model`` exactly once."""
        _, models_impl = inference_stack

        # Replace the real loader with an AsyncMock so no weights are loaded and
        # the call can be counted.
        # NOTE(review): only MetaReferenceInferenceImpl is patched, so this test
        # is presumably meant to run with the meta-reference provider (see the
        # "-k meta_reference" run instructions at the top of the file).
        with patch(
            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
            new_callable=AsyncMock,
        ) as mock_load_model:
            await models_impl.register_model(
                model_id="Llama3.1-8B-Instruct",
                metadata={
                    "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
                },
            )
            mock_load_model.assert_called_once()

    @pytest.mark.asyncio
    async def test_register_with_invalid_llama_model(self, inference_stack):
        """An unrecognised ``llama_model`` in metadata should raise ``ValueError``."""
        _, models_impl = inference_stack

        with pytest.raises(ValueError):
            await models_impl.register_model(
                model_id="custom-model-2",
                metadata={"llama_model": "invalid-llama-model"},
            )
|