Merge branch 'main' into sambanova-inferene

This commit is contained in:
snova-edwardm 2025-01-14 10:04:52 -08:00 committed by GitHub
commit 89ab2be302
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
385 changed files with 39001 additions and 9280 deletions

View file

@ -9,16 +9,18 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.vllm import VLLMConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
@ -48,6 +50,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
# If embedding dimension is set, use the 8B model for testing
if os.getenv("EMBEDDING_DIMENSION"):
inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
return ProviderFixture(
providers=[
@ -86,7 +91,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
if "Llama3.1-8B-Instruct" in inference_model:
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
return ProviderFixture(
@ -102,6 +107,26 @@ def inference_ollama(inference_model) -> ProviderFixture:
)
@pytest_asyncio.fixture(scope="session")
def inference_vllm(inference_model) -> ProviderFixture:
    """Provide one inline vLLM provider per requested inference model.

    ``inference_model`` may be a single model name or a collection of
    names; a bare string is treated as a one-element list. Each model gets
    its own provider, identified as ``vllm-<index>`` in iteration order.
    """
    model_names = inference_model
    if isinstance(model_names, str):
        model_names = [model_names]

    vllm_providers = []
    for index, model_name in enumerate(model_names):
        provider_config = VLLMConfig(
            model=model_name,
            enforce_eager=True,  # Make test run faster
        )
        vllm_providers.append(
            Provider(
                provider_id=f"vllm-{index}",
                provider_type="inline::vllm",
                config=provider_config.model_dump(),
            )
        )
    return ProviderFixture(providers=vllm_providers)
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
return ProviderFixture(
@ -111,6 +136,7 @@ def inference_vllm_remote() -> ProviderFixture:
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig(
url=get_env_or_fail("VLLM_URL"),
max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
).model_dump(),
)
],
@ -148,6 +174,22 @@ def inference_together() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_groq() -> ProviderFixture:
    """Provide the remote Groq inference provider.

    The provider uses a default ``GroqConfig``. The API key is obtained
    from the ``GROQ_API_KEY`` environment variable via ``get_env_or_fail``
    and passed along as provider data under ``groq_api_key``.
    """
    groq_provider = Provider(
        provider_id="groq",
        provider_type="remote::groq",
        config=GroqConfig().model_dump(),
    )
    return ProviderFixture(
        providers=[groq_provider],
        provider_data={"groq_api_key": get_env_or_fail("GROQ_API_KEY")},
    )
@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
return ProviderFixture(
@ -208,6 +250,18 @@ def inference_sambanova() -> ProviderFixture:
)
def inference_sentence_transformers() -> ProviderFixture:
    """Provide the inline sentence-transformers provider with an empty config."""
    # NOTE(review): no fixture decorator is visible on this function in this
    # view -- confirm whether it is meant to be called directly or registered
    # as a pytest fixture.
    sentence_transformers_provider = Provider(
        provider_id="sentence_transformers",
        provider_type="inline::sentence-transformers",
        config={},
    )
    return ProviderFixture(providers=[sentence_transformers_provider])
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
@ -238,6 +292,8 @@ INFERENCE_FIXTURES = [
"ollama",
"fireworks",
"together",
"vllm",
"groq",
"vllm_remote",
"remote",
"bedrock",
@ -252,11 +308,27 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model):
    """Build an inference-only test stack for the parametrized provider.

    The provider fixture is looked up by name from ``request.param``
    (``inference_<name>``). The model is registered as an LLM by default;
    when the ``EMBEDDING_DIMENSION`` environment variable is set, it is
    registered as an embedding model instead and the variable's value is
    stored in the model metadata under ``embedding_dimension``.

    Yields:
        A ``(inference_impl, models_impl)`` tuple taken from the constructed
        stack. Once the consumer is finished, the inference implementation
        is shut down.
    """
    fixture_name = request.param
    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")

    model_type = ModelType.llm
    metadata = {}
    if os.getenv("EMBEDDING_DIMENSION"):
        model_type = ModelType.embedding
        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")

    # Pass a single ``models=`` argument: the older one-line form
    # (``ModelInput(model_id=inference_model)``) is superseded by the version
    # that also carries the model type and metadata.
    test_stack = await construct_stack_for_test(
        [Api.inference],
        {"inference": inference_fixture.providers},
        inference_fixture.provider_data,
        models=[
            ModelInput(
                model_id=inference_model,
                model_type=model_type,
                metadata=metadata,
            )
        ],
    )

    # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
    # Yield (rather than return) so the cleanup below actually runs.
    yield test_stack.impls[Api.inference], test_stack.impls[Api.models]

    # Cleanup code that runs after test case completion
    await test_stack.impls[Api.inference].shutdown()