Merge branch 'main' into inference_refactor

This commit is contained in:
Botao Chen 2024-12-17 20:10:23 -08:00
commit fadb7deae5
79 changed files with 1547 additions and 2026 deletions

View file

@ -113,6 +113,7 @@ def inference_vllm_remote() -> ProviderFixture:
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig(
url=get_env_or_fail("VLLM_URL"),
max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
).model_dump(),
)
],
@ -192,6 +193,19 @@ def inference_tgi() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_sentence_transformers() -> ProviderFixture:
    """Build the provider fixture for the inline sentence-transformers provider.

    Returns:
        A ``ProviderFixture`` holding a single ``Provider`` registered under
        the id ``sentence_transformers`` with type
        ``inline::sentence-transformers`` and an empty config.
    """
    sentence_transformers_provider = Provider(
        provider_id="sentence_transformers",
        provider_type="inline::sentence-transformers",
        config={},
    )
    return ProviderFixture(providers=[sentence_transformers_provider])
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.

View file

@ -7,16 +7,19 @@
from pathlib import Path
import pytest
from PIL import Image as PIL_Image
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
from .utils import group_chunks
# Directory containing this test module; the sample image is resolved against it.
THIS_DIR = Path(__file__).parent

# Raw bytes of the sample image, read once when the module is imported.
PASTA_IMAGE = (THIS_DIR / "pasta.jpeg").read_bytes()
class TestVisionModelInference:
@pytest.mark.asyncio
@ -24,12 +27,12 @@ class TestVisionModelInference:
"image, expected_strings",
[
(
ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
ImageContentItem(data=PASTA_IMAGE),
["spaghetti"],
),
(
ImageMedia(
image=URL(
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@ -58,7 +61,12 @@ class TestVisionModelInference:
model_id=inference_model,
messages=[
UserMessage(content="You are a helpful assistant."),
UserMessage(content=[image, "Describe this image in two sentences."]),
UserMessage(
content=[
image,
TextContentItem(text="Describe this image in two sentences."),
]
),
],
stream=False,
sampling_params=SamplingParams(max_tokens=100),
@ -89,8 +97,8 @@ class TestVisionModelInference:
)
images = [
ImageMedia(
image=URL(
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@ -106,7 +114,12 @@ class TestVisionModelInference:
messages=[
UserMessage(content="You are a helpful assistant."),
UserMessage(
content=[image, "Describe this image in two sentences."]
content=[
image,
TextContentItem(
text="Describe this image in two sentences."
),
]
),
],
stream=True,