mirror of https://github.com/meta-llama/llama-stack.git

commit b8f1561956 (parent 23a99a4b22)

feat: introduce llama4 support (#1877)

As the title says. Details in README, elsewhere.

61 changed files with 205222 additions and 6439 deletions
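The headline change in the test suite below is multi-image chat completion. As a quick orientation, here is a minimal sketch of the call shape those tests exercise; the `LlamaStackClient` construction, endpoint, and model id are assumptions, while the `inference.chat_completion` arguments and message structure come directly from the new test code.

    import base64

    from llama_stack_client import LlamaStackClient  # client class assumed; the tests use fixtures instead

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed endpoint

    with open("vision_test_1.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    response = client.inference.chat_completion(
        model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # assumed llama4 model id
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": {"data": image_b64}},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
        stream=False,
    )
    print(response.completion_message.content)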
@@ -8,6 +8,7 @@ from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
+import requests
 from llama_stack_client import Agent, AgentEventLogger, Document
 from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
@@ -21,7 +22,7 @@ from llama_stack.apis.agents.agents import (
 
 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
     """
-    Returns the boiling point of a liquid in Celcius or Fahrenheit
+    Returns the boiling point of a liquid in Celcius or Fahrenheit.
 
     :param liquid_name: The name of the liquid
     :param celcius: Whether to return the boiling point in Celcius
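The hunk above only touches the docstring; the diff never shows the function body. For readers reconstructing the tool, a sketch of how such a client tool could look follows; the body and return values are assumptions for illustration, not the repository's actual implementation.

    def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
        """
        Returns the boiling point of a liquid in Celcius or Fahrenheit.

        :param liquid_name: The name of the liquid
        :param celcius: Whether to return the boiling point in Celcius
        """
        # Assumed body: the tests only assert that the tool gets called,
        # so a made-up value for the fictional liquid suffices.
        if liquid_name.lower() == "polyjuice":
            return -100 if celcius else -148
        return -1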
@@ -185,7 +186,7 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
         messages=[
             {
                 "role": "user",
-                "content": "Search the web and tell me what is the local time in Tokyo currently.",
+                "content": "Who are the latest board members to join Meta's board of directors?",
             }
         ],
         session_id=session_id,
@@ -429,19 +430,28 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
 
 
 def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
-    urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
+    urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passign as url
         Document(
-            document_id=f"num-{i}",
-            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            document_id="num-0",
+            content={
+                "type": "url",
+                "uri": f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[0]}",
+            },
             mime_type="text/plain",
             metadata={},
-        )
-        for i, url in enumerate(urls)
+        ),
+        # passing as str
+        Document(
+            document_id="num-1",
+            content=requests.get(
+                f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[1]}"
+            ).text[:500],
+            mime_type="text/plain",
+            metadata={},
+        ),
     ]
     agent_config = {
         **agent_config,
     }
     rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
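The rewritten fixture is the interesting part of this hunk: it now covers both `Document` content shapes the test exercises, a structured URL reference and inline text. Side by side, independent of the test scaffolding (only `Document` fields shown in the diff are used; the torchtune URLs come straight from the test):

    import requests
    from llama_stack_client import Document

    BASE = "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials"

    # Shape 1: content as a structured URL reference.
    doc_by_url = Document(
        document_id="num-0",
        content={"type": "url", "uri": f"{BASE}/llama3.rst"},
        mime_type="text/plain",
        metadata={},
    )

    # Shape 2: content as a plain string, fetched and truncated up front.
    doc_by_str = Document(
        document_id="num-1",
        content=requests.get(f"{BASE}/lora_finetune.rst").text[:500],
        mime_type="text/plain",
        metadata={},
    )

The test then threads these documents through the agent turn via its `(prompt, documents)` prompt list, visible in the next hunk; attaching them to a turn rather than pre-ingesting them is what makes this an attachment test.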
@@ -456,7 +466,7 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
             documents,
         ),
         (
-            "Tell me how to use LoRA",
+            "Tell me how to use LoRA in 100 words or less",
             None,
         ),
     ]
@@ -478,6 +488,9 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
 
 
 def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
+    if "llama-4" in agent_config["model"].lower():
+        pytest.xfail("Not working for llama4")
+
     documents = []
     documents.append(
         Document(
@@ -544,7 +557,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
         stream=False,
     )
     tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-    assert tool_execution_step.tool_calls[0].tool_name == tool_name
+    assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
     if expected_kw:
         assert expected_kw in response.output_message.content.lower()
 
@@ -565,18 +578,22 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
     agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
+    input_prompt = f"Call {client_tools[0].__name__} tool and answer What is the boiling point of polyjuice?"
     response = agent.create_turn(
         messages=[
             {
                 "role": "user",
-                "content": "Call get_boiling_point and answer What is the boiling point of polyjuice?",
+                "content": input_prompt,
             },
         ],
         session_id=session_id,
         stream=False,
     )
+    assert len(response.input_messages) == 1
+    assert input_prompt == response.input_messages[0].content
+
     steps = response.steps
-    assert len(steps) == 3
+    assert len(steps) >= 3  # some models call the tool twice
     assert steps[0].step_type == "inference"
     assert steps[1].step_type == "tool_execution"
     assert steps[1].tool_calls[0].tool_name.startswith("get_boiling_point")
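This test also documents the common agent loop used throughout the file: build an `Agent`, open a session, run a turn, then inspect the structured steps. Condensed into a standalone sketch (client construction and the minimal `agent_config` are assumptions; every method and attribute used appears in the test code above):

    from uuid import uuid4

    from llama_stack_client import Agent, LlamaStackClient  # LlamaStackClient assumed

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed endpoint
    agent_config = {"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct"}  # assumed; the fixture supplies more fields

    agent = Agent(client, **agent_config)
    session_id = agent.create_session(f"test-session-{uuid4()}")

    response = agent.create_turn(
        messages=[{"role": "user", "content": "What is the boiling point of polyjuice?"}],
        session_id=session_id,
        stream=False,
    )

    # A turn decomposes into steps, e.g. inference -> tool_execution -> inference.
    for step in response.steps:
        print(step.step_type)
    print(response.output_message.content)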
@@ -23,7 +23,12 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"):
+    if provider.provider_type in (
+        "remote::openai",
+        "remote::anthropic",
+        "remote::gemini",
+        "remote::groq",
+    ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 
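The reformatted tuple above is purely cosmetic, but the helper's lookup pattern is worth noting: resolve a model to its provider record, then gate behavior on `provider_type`. A standalone sketch of that lookup (the client construction is an assumption, and `models.list()` plus the `identifier` attribute are assumed by analogy with the `providers.list()` call in the helper):

    from llama_stack_client import LlamaStackClient  # assumed import

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed endpoint

    # Mirror the helper: index providers by id, then map each model to its provider type.
    providers = {p.provider_id: p for p in client.providers.list()}
    for model in client.models.list():  # assumed listing call
        provider = providers[model.provider_id]
        print(model.identifier, "->", provider.provider_type)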
@@ -4,11 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
 
+import base64
 import pathlib
+from pathlib import Path
 
 import pytest
 
+THIS_DIR = Path(__file__).parent
+
 
 @pytest.fixture
 def image_path():
@@ -27,7 +31,6 @@ def base64_image_url(base64_image_data, image_path):
     return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
 
 
-@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
 def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
@@ -56,7 +59,99 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
 
 
-@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
+@pytest.fixture
+def multi_image_data():
+    files = [
+        THIS_DIR / "vision_test_1.jpg",
+        THIS_DIR / "vision_test_2.jpg",
+        THIS_DIR / "vision_test_3.jpg",
+    ]
+    encoded_files = []
+    for file in files:
+        with open(file, "rb") as image_file:
+            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
+            encoded_files.append(base64_data)
+    return encoded_files
+
+
+@pytest.mark.parametrize("stream", [True, False])
+def test_image_chat_completion_multiple_images(client_with_models, vision_model_id, multi_image_data, stream):
+    if "llama-4" not in vision_model_id.lower() and "gpt-4o" not in vision_model_id.lower():
+        pytest.skip("Skip for non-llama4, gpt4o models")
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": {
+                        "data": multi_image_data[0],
+                    },
+                },
+                {
+                    "type": "image",
+                    "image": {
+                        "data": multi_image_data[1],
+                    },
+                },
+                {
+                    "type": "text",
+                    "text": "What are the differences between these images? Where would you assume they would be located?",
+                },
+            ],
+        },
+    ]
+    response = client_with_models.inference.chat_completion(
+        model_id=vision_model_id,
+        messages=messages,
+        stream=stream,
+    )
+    if stream:
+        message_content = ""
+        for chunk in response:
+            message_content += chunk.event.delta.text
+    else:
+        message_content = response.completion_message.content
+    assert len(message_content) > 0
+    assert any(expected in message_content.lower().strip() for expected in {"bedroom"}), message_content
+
+    messages.append(
+        {
+            "role": "assistant",
+            "content": [{"type": "text", "text": message_content}],
+            "stop_reason": "end_of_turn",
+        }
+    )
+    messages.append(
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": {
+                        "data": multi_image_data[2],
+                    },
+                },
+                {"type": "text", "text": "How about this one?"},
+            ],
+        },
+    )
+    response = client_with_models.inference.chat_completion(
+        model_id=vision_model_id,
+        messages=messages,
+        stream=stream,
+    )
+    if stream:
+        message_content = ""
+        for chunk in response:
+            message_content += chunk.event.delta.text
+    else:
+        message_content = response.completion_message.content
+    assert len(message_content) > 0
+    assert any(expected in message_content.lower().strip() for expected in {"sword", "shield"}), message_content
+
+
 def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
BIN  tests/integration/inference/vision_test_1.jpg  (new file; binary file not shown; 108 KiB)
BIN  tests/integration/inference/vision_test_2.jpg  (new file; binary file not shown; 148 KiB)
BIN  tests/integration/inference/vision_test_3.jpg  (new file; binary file not shown; 139 KiB)