llama-stack/llama_stack/providers/tests/inference/fixtures.py
snova-edwardm 22dc684da6
Sambanova inference provider (#555)
# What does this PR do?

This PR adds SambaNova as one of the Provider

- Add SambaNova as a provider

## Test Plan
Test the functional command
```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_embeddings.py llama_stack/providers/tests/inference/test_prompt_adapter.py llama_stack/providers/tests/inference/test_text_inference.py llama_stack/providers/tests/inference/test_vision_inference.py --env SAMBANOVA_API_KEY=<sambanova-api-key>
```

Test the distribution template:
```
# Docker
LLAMA_STACK_PORT=5001
docker run -it -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-sambanova \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY

# Conda
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

## Source
[SambaNova API Documentation](https://cloud.sambanova.ai/apis)

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [Y] Ran pre-commit to handle lint / formatting issues.
- [Y] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [Y] Updated relevant documentation.
- [Y ] Wrote necessary unit or integration tests.

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-01-23 12:20:28 -08:00

335 lines
9.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.vllm import VLLMConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def inference_model(request):
if hasattr(request, "param"):
return request.param
return request.config.getoption("--inference-model", None)
@pytest.fixture(scope="session")
def inference_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def inference_meta_reference(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
# If embedding dimension is set, use the 8B model for testing
if os.getenv("EMBEDDING_DIMENSION"):
inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
return ProviderFixture(
providers=[
Provider(
provider_id=f"meta-reference-{i}",
provider_type="inline::meta-reference",
config=MetaReferenceInferenceConfig(
model=m,
max_seq_len=4096,
create_distributed_process_group=False,
checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
).model_dump(),
)
for i, m in enumerate(inference_model)
]
)
@pytest.fixture(scope="session")
def inference_cerebras() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="cerebras",
provider_type="remote::cerebras",
config=CerebrasImplConfig(
api_key=get_env_or_fail("CEREBRAS_API_KEY"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
return ProviderFixture(
providers=[
Provider(
provider_id="ollama",
provider_type="remote::ollama",
config=OllamaImplConfig(
host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
).model_dump(),
)
],
)
@pytest_asyncio.fixture(scope="session")
def inference_vllm(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
return ProviderFixture(
providers=[
Provider(
provider_id=f"vllm-{i}",
provider_type="inline::vllm",
config=VLLMConfig(
model=m,
enforce_eager=True, # Make test run faster
).model_dump(),
)
for i, m in enumerate(inference_model)
]
)
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="remote::vllm",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig(
url=get_env_or_fail("VLLM_URL"),
max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_fireworks() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="fireworks",
provider_type="remote::fireworks",
config=FireworksImplConfig(
api_key=get_env_or_fail("FIREWORKS_API_KEY"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_together() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="together",
provider_type="remote::together",
config=TogetherImplConfig().model_dump(),
)
],
provider_data=dict(
together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
),
)
@pytest.fixture(scope="session")
def inference_groq() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="groq",
provider_type="remote::groq",
config=GroqConfig().model_dump(),
)
],
provider_data=dict(
groq_api_key=get_env_or_fail("GROQ_API_KEY"),
),
)
@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="bedrock",
provider_type="remote::bedrock",
config=BedrockConfig().model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_nvidia() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAConfig().model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_tgi() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="tgi",
provider_type="remote::tgi",
config=TGIImplConfig(
url=get_env_or_fail("TGI_URL"),
api_token=os.getenv("TGI_API_TOKEN", None),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_sambanova() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="sambanova",
provider_type="remote::sambanova",
config=SambaNovaImplConfig(
api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
).model_dump(),
)
],
provider_data=dict(
sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
),
)
def inference_sentence_transformers() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="sentence_transformers",
provider_type="inline::sentence-transformers",
config={},
)
]
)
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
Args:
model_name: Full model name like "Llama3.1-8B-Instruct"
Returns:
Short name like "llama_8b" suitable for test markers
"""
model_name = model_name.lower()
if "vision" in model_name:
return "llama_vision"
elif "3b" in model_name:
return "llama_3b"
elif "8b" in model_name:
return "llama_8b"
else:
return model_name.replace(".", "_").replace("-", "_")
@pytest.fixture(scope="session")
def model_id(inference_model) -> str:
return get_model_short_name(inference_model)
INFERENCE_FIXTURES = [
"meta_reference",
"ollama",
"fireworks",
"together",
"vllm",
"groq",
"vllm_remote",
"remote",
"bedrock",
"cerebras",
"nvidia",
"tgi",
"sambanova",
]
@pytest_asyncio.fixture(scope="session")
async def inference_stack(request, inference_model):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
model_type = ModelType.llm
metadata = {}
if os.getenv("EMBEDDING_DIMENSION"):
model_type = ModelType.embedding
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
test_stack = await construct_stack_for_test(
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[
ModelInput(
provider_id=inference_fixture.providers[0].provider_id,
model_id=inference_model,
model_type=model_type,
metadata=metadata,
)
],
)
# Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
# Cleanup code that runs after test case completion
await test_stack.impls[Api.inference].shutdown()