Merge branch 'meta-llama:main' into main

Commit 1c23d72dac: 56 changed files with 2918 additions and 1732 deletions
@@ -22,7 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
@@ -21,7 +21,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -22,7 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
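All four distributions now ship the `remote::wolfram-alpha` tool runtime alongside the existing search tools. For reference, a minimal sketch of invoking it through the tool-runtime API, assuming a stack built from one of these distributions with a Wolfram Alpha API key configured; the call shape mirrors the provider test code that is removed later in this diff.

```python
# Sketch only: call the builtin wolfram_alpha tool through a tool_runtime implementation.
# Assumes `tool_runtime_impl` is the Api.tool_runtime impl of a running stack that has the
# remote::wolfram-alpha provider configured; names mirror the removed provider tests.

async def ask_wolfram(tool_runtime_impl, question: str) -> str:
    result = await tool_runtime_impl.invoke_tool(
        tool_name="wolfram_alpha",
        kwargs={"query": question},
    )
    # invoke_tool returns a ToolInvocationResult; .content carries the answer text
    return result.content
```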
@@ -366,7 +366,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
             provider_id = list(self.impls_by_provider_id.keys())[0]
         else:
             raise ValueError(
-                "No provider specified and multiple providers available. Please specify a provider_id."
+                f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
             )
         if metadata is None:
             metadata = {}
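The sharper error message now lists the providers the caller can choose from. A hedged sketch of avoiding that branch entirely by passing an explicit provider_id when registering a dataset; the routing-table code above implies register_dataset accepts one, the schema and URL types are the ones used by the removed datasetio tests, and the "localfs" value is only illustrative.

```python
# Sketch: register a dataset with an explicit provider_id so DatasetsRoutingTable does not
# have to guess among multiple datasetio providers. Types below appear in the removed
# datasetio tests; the provider_id value is illustrative.

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import StringType


async def register_eval_dataset(datasets_impl, data_url: str):
    await datasets_impl.register_dataset(
        dataset_id="my_eval_dataset",
        provider_id="localfs",  # explicit choice avoids the ValueError raised above
        dataset_schema={
            "input_query": StringType(),
            "expected_answer": StringType(),
            "generated_answer": StringType(),
        },
        url=URL(uri=data_url),
    )
```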
@@ -12,7 +12,7 @@ import secrets
 import string
 import uuid
 from datetime import datetime
-from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
 from urllib.parse import urlparse

 import httpx
@@ -31,7 +31,6 @@ from llama_stack.apis.agents import (
    AgentTurnResponseStreamChunk,
    AgentTurnResponseTurnAwaitingInputPayload,
    AgentTurnResponseTurnCompletePayload,
    AgentTurnResponseTurnStartPayload,
    AgentTurnResumeRequest,
    Attachment,
    Document,
@@ -184,115 +183,49 @@ class ChatAgent(ShieldRunnerMixin):
The duplicated turn-execution body of create_and_execute_turn is removed: loading the session info and prior
turns, emitting the AgentTurnResponseTurnStartPayload chunk, driving self.run(...) with the agent's sampling
params, documents, and toolgroups, collecting step_complete step details, persisting the Turn, and emitting
the final awaiting-input/turn-complete chunk. The method keeps its tracing-span attributes and the
"Non-streaming not supported" assertion and then delegates:

    async for chunk in self._run_turn(request, turn_id):
        yield chunk

resume_turn is reduced the same way, setting its span attributes and delegating to self._run_turn(request).
Both now share a new helper:

    async def _run_turn(
        self,
        request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
        turn_id: Optional[str] = None,
    ) -> AsyncGenerator:
        assert request.stream is True, "Non-streaming not supported"

_run_turn validates the session (raising "Session {id} not found" when missing), loads the stored turns
(raising "No turns found for session" only when resuming), and builds the message history with
get_messages_from_turns. When the request is an AgentTurnResumeRequest it appends request.tool_responses,
filters the last turn's messages down to UserMessage/ToolResponseMessage entries and extends them with the
tool responses (with a TODO noting this may change), reuses the last turn's steps, and marks any in-progress
tool-execution step as complete when storage or client-side tool-call parsing left none recorded.
@@ -326,62 +259,67 @@ class ChatAgent(ShieldRunnerMixin):
The remainder of _run_turn absorbs the logic the two methods used to duplicate. On resume it uses
input_messages = last_turn_messages, turn_id = request.turn_id, and start_time = last_turn.started_at;
otherwise it extends messages with request.messages, takes a fresh datetime.now().astimezone().isoformat()
start time, and uses input_messages = request.messages. It then runs the single shared streaming loop:

    output_message = None
    async for chunk in self.run(
        session_id=request.session_id,
        turn_id=turn_id,
        input_messages=messages,
        sampling_params=self.agent_config.sampling_params,
        stream=request.stream,
        documents=request.documents if not is_resume else None,
        toolgroups_for_turn=request.toolgroups if not is_resume else None,
    ):
        ...

Inside the loop a CompletionMessage chunk is captured as output_message; every other chunk must be an
AgentTurnResponseStreamChunk, is yielded to the caller, and contributes its step_details to steps when its
event type is step_complete. After the loop, output_message must not be None. The helper then builds a single
Turn(turn_id=turn_id, session_id=request.session_id, input_messages=input_messages,
output_message=output_message, started_at=start_time, completed_at=datetime.now().astimezone().isoformat(),
steps=steps), stores it via self.storage.add_turn_to_session, and yields one final
AgentTurnResponseStreamChunk whose payload is AgentTurnResponseTurnAwaitingInputPayload when
output_message.tool_calls is non-empty and AgentTurnResponseTurnCompletePayload otherwise. The per-method
copies of this persistence and final-chunk logic (including the old resume path that rebuilt
last_turn_start_time from turns[-1] and constructed the Turn from last_turn_messages) are deleted.
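For orientation, a sketch of how a caller is expected to drive this stream: consume chunks until the final payload arrives, and treat an awaiting-input payload as the signal to execute the pending tool calls and resume the turn with their responses. The payload names come from the hunks above; everything else is assumed application code.

```python
# Sketch: consume an agent turn stream produced by the refactored _run_turn. Assumes
# `turn_stream` is the async generator returned by create_and_execute_turn; the payload
# classes are the ones used above, the surrounding flow is illustrative.

from llama_stack.apis.agents import (
    AgentTurnResponseTurnAwaitingInputPayload,
    AgentTurnResponseTurnCompletePayload,
)


async def drive_turn(turn_stream):
    async for chunk in turn_stream:
        payload = chunk.event.payload
        if isinstance(payload, AgentTurnResponseTurnCompletePayload):
            return payload.turn  # finished turn, nothing left to do
        if isinstance(payload, AgentTurnResponseTurnAwaitingInputPayload):
            # the output message carried tool_calls: run them, then call resume_turn(...)
            # with the resulting ToolResponseMessages to continue this same turn
            return payload.turn
        # otherwise: intermediate step events (e.g. step_complete), useful for progress UIs
```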
@@ -1,29 +0,0 @@
Deleted file (29 lines): a pytest conftest that registered a marker for each name in DATASETIO_FIXTURES via
pytest_configure and parametrized the "datasetio_stack" fixture over those names in pytest_generate_tests.
@@ -1,61 +0,0 @@
Deleted file (61 lines): provider fixtures for the datasetio tests: session-scoped datasetio_remote,
datasetio_localfs (inline::localfs), and datasetio_huggingface (remote::huggingface) ProviderFixtures, the
DATASETIO_FIXTURES list ("localfs", "remote", "huggingface"), and a datasetio_stack fixture that built a test
stack with construct_stack_for_test and returned the datasetio and datasets impls.
@@ -1,134 +0,0 @@
Deleted file (134 lines): the old datasetio provider tests: a data_url_from_file helper that base64-encodes a
local CSV into a data URL, a register_dataset helper that registers test_dataset.csv or test_rag_dataset.csv
with a generation, RAG, or default StringType schema, and a TestDatasetIO class covering list_datasets,
register/unregister (including unregistering a nonexistent dataset), and get_rows_paginated pagination
(rows_in_page=3 then 2, with next_page_token "3" and "5"; the remote provider is skipped).
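Those tests were the main in-tree reference for the pagination contract, so a small hedged sketch of the same flow is kept here; the parameter and field names are exactly the ones the removed test exercised, and the dataset is assumed to be registered already.

```python
# Sketch: page through a registered dataset with get_rows_paginated, mirroring the removed
# test. Assumes `datasetio_impl` is an Api.datasetio implementation and the dataset exists.

async def iter_all_rows(datasetio_impl, dataset_id: str = "test_dataset", page_size: int = 3):
    page = await datasetio_impl.get_rows_paginated(dataset_id=dataset_id, rows_in_page=page_size)
    rows = list(page.rows)
    while page.rows and page.next_page_token:
        page = await datasetio_impl.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=page_size,
            page_token=page.next_page_token,
        )
        rows.extend(page.rows)
    return rows
```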
@@ -1,92 +0,0 @@
Deleted file (92 lines): the eval test conftest defining DEFAULT_PROVIDER_COMBINATIONS
(meta_reference_eval_fireworks_inference, meta_reference_eval_together_inference, and
meta_reference_eval_together_inference_huggingface_datasetio, each pairing meta_reference eval and agents with
basic scoring, llama_guard safety, faiss vector_io, and the memory_and_search tool runtime), registering the
matching markers, and parametrizing the "eval_stack" fixture over provider fixture overrides.
@@ -1,87 +0,0 @@
Deleted file (87 lines): eval provider fixtures: eval_remote and eval_meta_reference (inline::meta-reference)
ProviderFixtures, the EVAL_FIXTURES list, and an eval_stack session fixture that assembled eval, datasetio,
scoring, inference, agents, safety, vector_io, and tool_runtime providers into a test stack registered with
the inference and judge models plus the memory and tavily-search tool groups.
@@ -1,42 +0,0 @@
Deleted file (42 lines): the post-training test conftest with a single
torchtune_post_training_huggingface_datasetio provider combination, its marker registration, and
parametrization of the "post_training_stack" fixture over provider fixture overrides.
@@ -1,72 +0,0 @@
Deleted file (72 lines): post-training fixtures: a post_training_torchtune ProviderFixture
(inline::torchtune), the POST_TRAINING_FIXTURES list, and a post_training_stack fixture that built a
post_training/datasetio stack with meta-llama/Llama-3.2-3B-Instruct and registered the "alpaca" dataset
(https://huggingface.co/datasets/tatsu-lab/alpaca, split "train") with an instruction/input/output/text
StringType schema.
@@ -1,75 +0,0 @@
Deleted file (75 lines): the scoring test conftest defining the basic, braintrust, and llm_as_judge scoring
combinations over localfs datasetio and together inference, registering the matching markers, parametrizing
"judge_model" from the --judge-model option, and parametrizing the "scoring_stack" fixture.
@@ -1,100 +0,0 @@
Deleted file (100 lines): scoring fixtures: scoring_remote, scoring_basic (inline::basic), scoring_braintrust
(inline::braintrust configured from OPENAI_API_KEY), and scoring_llm_as_judge (inline::llm-as-judge)
ProviderFixtures, the judge_model fixture, the SCORING_FIXTURES list, and a scoring_stack fixture that built a
scoring/datasetio/inference stack registered with the inference and judge models.
@@ -1,213 +0,0 @@
Deleted file (213 lines): the old scoring provider tests: a sample_judge_prompt_template fixture ("Output a
number response in the following format: Score: <number> ...") and a TestScoring class covering
list_scoring_functions, score and score_batch over a registered RAG dataset, llm-as-judge scoring with
LLMAsJudgeScoringFnParams (judge_model, prompt_template, judge_score_regexes, aggregation_functions), and
per-provider aggregation functions (accuracy, median, categorical_count, average) via BasicScoringFnParams
and RegexParserScoringFnParams, with braintrust/basic and llm-as-judge skipped where parameters do not apply.
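A condensed, hedged sketch of the llm-as-judge scoring call those tests covered; the parameter names, prompt template, regex, and the `llm-as-judge::base` identifier are taken from the removed test, while the judge model and input rows are placeholders.

```python
# Sketch: score rows with the llm-as-judge scoring function, using the same parameters the
# removed test passed. `scoring_impl` is an Api.scoring implementation; judge_model and the
# input rows are placeholders.

from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    LLMAsJudgeScoringFnParams,
)


async def judge_rows(scoring_impl, rows: list[dict], judge_model: str):
    params = LLMAsJudgeScoringFnParams(
        judge_model=judge_model,
        prompt_template=(
            "Output a number response in the following format: Score: <number>, "
            "where <number> is the number between 0 and 9."
        ),
        judge_score_regexes=[r"Score: (\d+)"],
        aggregation_functions=[AggregationFunctionType.categorical_count],
    )
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions={"llm-as-judge::base": params},
    )
    return response.results["llm-as-judge::base"].score_rows
```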
@@ -1,5 +0,0 @@
Deleted file (5 lines): contained only the Meta Platforms copyright and license header.
@@ -1,48 +0,0 @@
Deleted file (48 lines): the tools test conftest with a single "together" provider combination (together
inference, llama_guard safety, faiss vector_io, memory_and_search tool_runtime), its marker registration, and
parametrization of the "tools_stack" fixture over provider fixture overrides.
@@ -1,133 +0,0 @@
Deleted file (133 lines): tool-runtime fixtures: tool_runtime_memory_and_search (inline::rag-runtime plus
remote::tavily-search and remote::wolfram-alpha, configured from TAVILY_SEARCH_API_KEY and
WOLFRAM_ALPHA_API_KEY), ToolGroupInput fixtures for builtin::rag, builtin::web_search, and
builtin::wolfram_alpha, the TOOL_RUNTIME_FIXTURES list, and a tools_stack fixture that added an
inline::sentence-transformers provider, registered the inference models and the all-MiniLM-L6-v2 embedding
model (dimension 384), and built the tool_groups/inference/vector_io/tool_runtime test stack.
@@ -1,109 +0,0 @@
Deleted file (109 lines): the old tool-runtime tests: sample query and RAGDocument fixtures (torchtune
tutorial .rst URLs) and a TestTools class that invoked the web_search and wolfram_alpha tools through
tools_impl.invoke_tool (skipping when TAVILY_SEARCH_API_KEY or WOLFRAM_ALPHA_API_KEY is unset) and exercised
the RAG tool by registering a faiss vector DB, inserting documents with rag_tool.insert, and querying with
rag_tool.query, asserting ToolInvocationResult / RAGQueryResult content in each case.
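The RAG portion of those tests is worth keeping visible as a usage reference. A hedged sketch of the same insert-then-query flow follows; identifiers and parameters are copied from the removed test, and the vector_dbs and tool_runtime impls are assumed to come from a running stack with a faiss vector_io provider and the all-MiniLM-L6-v2 embedding model registered.

```python
# Sketch: index documents and query them through the RAG tool, mirroring the removed test.
# Assumes `vector_dbs_impl` / `tools_impl` are the Api.vector_dbs / Api.tool_runtime impls
# of a stack with a faiss provider; the bank id and URLs are placeholders.

from llama_stack.apis.tools import RAGDocument


async def index_and_query(vector_dbs_impl, tools_impl, urls: list[str], question: str):
    await vector_dbs_impl.register_vector_db(
        vector_db_id="docs_bank",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_id="faiss",
    )
    documents = [
        RAGDocument(document_id=f"num-{i}", content=url, mime_type="text/plain", metadata={})
        for i, url in enumerate(urls)
    ]
    await tools_impl.rag_tool.insert(
        documents=documents,
        vector_db_id="docs_bank",
        chunk_size_in_tokens=512,
    )
    result = await tools_impl.rag_tool.query(
        content=question,
        vector_db_ids=["docs_bank"],
    )
    return result.content  # RAGQueryResult.content holds the retrieved context
```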
@@ -27,6 +27,7 @@ distribution_spec:
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search
+    - remote::wolfram-alpha
     - inline::code-interpreter
     - inline::rag-runtime
     - remote::model-context-protocol
@@ -35,6 +35,7 @@ def get_distribution_template() -> DistributionTemplate:
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",
+            "remote::wolfram-alpha",
             "inline::code-interpreter",
             "inline::rag-runtime",
             "remote::model-context-protocol",
@@ -77,6 +78,10 @@ def get_distribution_template() -> DistributionTemplate:
             toolgroup_id="builtin::websearch",
             provider_id="tavily-search",
         ),
+        ToolGroupInput(
+            toolgroup_id="builtin::wolfram_alpha",
+            provider_id="wolfram-alpha",
+        ),
         ToolGroupInput(
             toolgroup_id="builtin::rag",
             provider_id="rag-runtime",
@@ -86,6 +86,9 @@ providers:
     config:
       api_key: ${env.TAVILY_SEARCH_API_KEY:}
       max_results: 3
+  - provider_id: wolfram-alpha
+    provider_type: remote::wolfram-alpha
+    config: {}
   - provider_id: code-interpreter
     provider_type: inline::code-interpreter
     config: {}
@@ -225,6 +228,8 @@ benchmarks: []
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
+- toolgroup_id: builtin::wolfram_alpha
+  provider_id: wolfram-alpha
 - toolgroup_id: builtin::rag
   provider_id: rag-runtime
 - toolgroup_id: builtin::code_interpreter
@@ -80,6 +80,9 @@ providers:
     config:
       api_key: ${env.TAVILY_SEARCH_API_KEY:}
       max_results: 3
+  - provider_id: wolfram-alpha
+    provider_type: remote::wolfram-alpha
+    config: {}
   - provider_id: code-interpreter
     provider_type: inline::code-interpreter
     config: {}
@@ -214,6 +217,8 @@ benchmarks: []
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
+- toolgroup_id: builtin::wolfram_alpha
+  provider_id: wolfram-alpha
 - toolgroup_id: builtin::rag
   provider_id: rag-runtime
 - toolgroup_id: builtin::code_interpreter
@@ -29,4 +29,5 @@ distribution_spec:
     - inline::code-interpreter
     - inline::rag-runtime
     - remote::model-context-protocol
+    - remote::wolfram-alpha
 image_type: conda
@@ -34,6 +34,7 @@ def get_distribution_template() -> DistributionTemplate:
             "inline::code-interpreter",
             "inline::rag-runtime",
             "remote::model-context-protocol",
+            "remote::wolfram-alpha",
         ],
     }
     name = "ollama"
@@ -78,6 +79,10 @@ def get_distribution_template() -> DistributionTemplate:
             toolgroup_id="builtin::code_interpreter",
             provider_id="code-interpreter",
         ),
+        ToolGroupInput(
+            toolgroup_id="builtin::wolfram_alpha",
+            provider_id="wolfram-alpha",
+        ),
     ]

     return DistributionTemplate(
@@ -85,6 +85,9 @@ providers:
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
     config: {}
+  - provider_id: wolfram-alpha
+    provider_type: remote::wolfram-alpha
+    config: {}
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
@@ -119,5 +122,7 @@ tool_groups:
   provider_id: rag-runtime
 - toolgroup_id: builtin::code_interpreter
   provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+  provider_id: wolfram-alpha
 server:
   port: 8321
@@ -82,6 +82,9 @@ providers:
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
     config: {}
+  - provider_id: wolfram-alpha
+    provider_type: remote::wolfram-alpha
+    config: {}
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
@@ -108,5 +111,7 @@ tool_groups:
   provider_id: rag-runtime
 - toolgroup_id: builtin::code_interpreter
   provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+  provider_id: wolfram-alpha
 server:
   port: 8321
@@ -30,4 +30,5 @@ distribution_spec:
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
    - remote::wolfram-alpha
image_type: conda

@@ -96,6 +96,9 @@ providers:
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
  - provider_id: wolfram-alpha
    provider_type: remote::wolfram-alpha
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db

@@ -126,5 +129,7 @@ tool_groups:
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
  provider_id: wolfram-alpha
server:
  port: 8321

@@ -90,6 +90,9 @@ providers:
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
  - provider_id: wolfram-alpha
    provider_type: remote::wolfram-alpha
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db

@@ -115,5 +118,7 @@ tool_groups:
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
  provider_id: wolfram-alpha
server:
  port: 8321

@@ -37,6 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
            "remote::wolfram-alpha",
        ],
    }
    name = "remote-vllm"

@@ -87,6 +88,10 @@ def get_distribution_template() -> DistributionTemplate:
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::wolfram_alpha",
            provider_id="wolfram-alpha",
        ),
    ]

    return DistributionTemplate(

@@ -30,4 +30,5 @@ distribution_spec:
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
    - remote::wolfram-alpha
image_type: conda

@@ -95,6 +95,9 @@ providers:
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
  - provider_id: wolfram-alpha
    provider_type: remote::wolfram-alpha
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db

@@ -226,5 +229,7 @@ tool_groups:
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
  provider_id: wolfram-alpha
server:
  port: 8321

@@ -89,6 +89,9 @@ providers:
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
  - provider_id: wolfram-alpha
    provider_type: remote::wolfram-alpha
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db

@@ -215,5 +218,7 @@ tool_groups:
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
  provider_id: wolfram-alpha
server:
  port: 8321

@@ -38,6 +38,7 @@ def get_distribution_template() -> DistributionTemplate:
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
            "remote::wolfram-alpha",
        ],
    }
    name = "together"

@@ -73,6 +74,10 @@ def get_distribution_template() -> DistributionTemplate:
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::wolfram_alpha",
            provider_id="wolfram-alpha",
        ),
    ]
    embedding_model = ModelInput(
        model_id="all-MiniLM-L6-v2",

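All of the templates above now register a `builtin::wolfram_alpha` tool group backed by the `remote::wolfram-alpha` provider. As a rough sketch (not part of this diff) of how a client could exercise the new tool group once `WOLFRAM_ALPHA_API_KEY` is available to the stack; the `invoke_tool` call shape and the `query` argument name are assumptions for illustration:

```python
# Sketch only: exercising the builtin::wolfram_alpha tool group added by the
# templates above. The invoke_tool signature and the "query" argument name are
# assumptions for illustration, not part of this change.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

result = client.tool_runtime.invoke_tool(
    tool_name="wolfram_alpha",
    kwargs={"query": "integrate x^2 dx"},
)
print(result.content)
```
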
@@ -20,7 +20,7 @@ from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack.env import get_env_or_fail
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

from .fixtures.recordable_mock import RecordableMock

@@ -84,6 +84,11 @@ def pytest_addoption(parser):
        default=None,
        help="Specify the embedding model to use for testing",
    )
    parser.addoption(
        "--judge-model",
        default=None,
        help="Specify the judge model to use for testing",
    )
    parser.addoption(
        "--embedding-dimension",
        type=int,

@@ -109,6 +114,7 @@ def provider_data():
        "TOGETHER_API_KEY": "together_api_key",
        "ANTHROPIC_API_KEY": "anthropic_api_key",
        "GROQ_API_KEY": "groq_api_key",
        "WOLFRAM_ALPHA_API_KEY": "wolfram_alpha_api_key",
    }
    provider_data = {}
    for key, value in keymap.items():

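The keymap above turns environment variables into provider-data keys for the test fixtures, so Wolfram Alpha requests can carry a key per request. A minimal sketch of how that mapping might be handed to a library client (the constructor arguments and the template name are assumptions for illustration, and the key value is a placeholder):

```python
# Illustrative only: passing provider_data, including the new
# wolfram_alpha_api_key entry, to a library-mode client. Template name and
# exact constructor shape are assumptions, not part of this diff.
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
    "ollama",
    provider_data={"wolfram_alpha_api_key": "your-real-key"},
)
```
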
@@ -260,7 +266,9 @@ def inference_provider_type(llama_stack_client):


@pytest.fixture(scope="session")
def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
def client_with_models(
    llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension, judge_model_id
):
    client = llama_stack_client

    providers = [p for p in client.providers.list() if p.api == "inference"]

@@ -274,6 +282,8 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
        client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
    if vision_model_id and vision_model_id not in model_ids:
        client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
    if judge_model_id and judge_model_id not in model_ids:
        client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])

    if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
        # try to find a provider that supports embeddings, if sentence-transformers is not available

@@ -328,6 +338,14 @@ def pytest_generate_tests(metafunc):
        if val is not None:
            id_parts.append(f"emb={get_short_id(val)}")

    if "judge_model_id" in metafunc.fixturenames:
        params.append("judge_model_id")
        val = metafunc.config.getoption("--judge-model")
        print(f"judge_model_id: {val}")
        values.append(val)
        if val is not None:
            id_parts.append(f"judge={get_short_id(val)}")

    if "embedding_dimension" in metafunc.fixturenames:
        params.append("embedding_dimension")
        val = metafunc.config.getoption("--embedding-dimension")

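Together, the conftest hunks above add a `--judge-model` option, parametrize a `judge_model_id` value, and have `client_with_models` register the judge model when it is missing. A minimal sketch of a test that relies on this wiring (run with something like `pytest ... --judge-model=<model-id>`; the test itself is illustrative, not part of this diff):

```python
import pytest


def test_judge_model_is_registered(client_with_models, judge_model_id):
    # Illustrative only: fixture names come from the conftest changes above;
    # the assertions are an assumption for the example.
    if judge_model_id is None:
        pytest.skip("run with --judge-model to exercise the judge model path")
    # client_with_models registers judge_model_id with the first inference
    # provider when it is not already present, so it should now be listed.
    model_ids = {m.identifier for m in client_with_models.models.list()}
    assert judge_model_id in model_ids
```
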
tests/integration/datasetio/test_datasetio.py (new file, 118 lines)

@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import mimetypes
import os
from pathlib import Path

import pytest

# How to run this test:
#
# pytest llama_stack/providers/tests/datasetio/test_datasetio.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings


def data_url_from_file(file_path: str) -> str:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)

    data_url = f"data:{mime_type};base64,{base64_content}"

    return data_url


def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
    if for_rag:
        test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
    else:
        test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
    test_url = data_url_from_file(str(test_file))

    if for_generation:
        dataset_schema = {
            "expected_answer": {"type": "string"},
            "input_query": {"type": "string"},
            "chat_completion_input": {"type": "chat_completion_input"},
        }
    elif for_rag:
        dataset_schema = {
            "expected_answer": {"type": "string"},
            "input_query": {"type": "string"},
            "generated_answer": {"type": "string"},
            "context": {"type": "string"},
        }
    else:
        dataset_schema = {
            "expected_answer": {"type": "string"},
            "input_query": {"type": "string"},
            "generated_answer": {"type": "string"},
        }

    llama_stack_client.datasets.register(
        dataset_id=dataset_id,
        dataset_schema=dataset_schema,
        url=dict(uri=test_url),
        provider_id="localfs",
    )


def test_datasets_list(llama_stack_client):
    # NOTE: this needs you to ensure that you are starting from a clean state
    # but so far we don't have an unregister API unfortunately, so be careful

    response = llama_stack_client.datasets.list()
    assert isinstance(response, list)
    assert len(response) == 0


def test_register_dataset(llama_stack_client):
    register_dataset(llama_stack_client)
    response = llama_stack_client.datasets.list()
    assert isinstance(response, list)
    assert len(response) == 1
    assert response[0].identifier == "test_dataset"

    with pytest.raises(ValueError):
        # unregister a dataset that does not exist
        llama_stack_client.datasets.unregister("test_dataset2")

    llama_stack_client.datasets.unregister("test_dataset")
    response = llama_stack_client.datasets.list()
    assert isinstance(response, list)
    assert len(response) == 0

    with pytest.raises(ValueError):
        llama_stack_client.datasets.unregister("test_dataset")


def test_get_rows_paginated(llama_stack_client):
    register_dataset(llama_stack_client)
    response = llama_stack_client.datasetio.get_rows_paginated(
        dataset_id="test_dataset",
        rows_in_page=3,
    )
    assert isinstance(response.rows, list)
    assert len(response.rows) == 3
    assert response.next_page_token == "3"

    # iterate over all rows
    response = llama_stack_client.datasetio.get_rows_paginated(
        dataset_id="test_dataset",
        rows_in_page=2,
        page_token=response.next_page_token,
    )
    assert isinstance(response.rows, list)
    assert len(response.rows) == 2
    assert response.next_page_token == "5"

@@ -10,15 +10,13 @@ import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.eval.eval import (
    AppBenchmarkConfig,
    BenchmarkBenchmarkConfig,
    ModelCandidate,
)
from llama_stack.apis.inference import SamplingParams
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset

from ..datasetio.test_datasetio import register_dataset
from .constants import JUDGE_PROMPT

# How to run this test:

@@ -28,6 +26,7 @@ from .constants import JUDGE_PROMPT
# -v -s --tb=short --disable-warnings


@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
class Testeval:
    @pytest.mark.asyncio
    async def test_benchmarks_list(self, eval_stack):

@@ -68,7 +67,7 @@ class Testeval:
            benchmark_id=benchmark_id,
            input_rows=rows.rows,
            scoring_functions=scoring_functions,
            benchmark_config=AppBenchmarkConfig(
            benchmark_config=dict(
                eval_candidate=ModelCandidate(
                    model=inference_model,
                    sampling_params=SamplingParams(),

@@ -111,7 +110,7 @@ class Testeval:
        )
        response = await eval_impl.run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=AppBenchmarkConfig(
            benchmark_config=dict(
                eval_candidate=ModelCandidate(
                    model=inference_model,
                    sampling_params=SamplingParams(),

@@ -169,7 +168,7 @@ class Testeval:
        benchmark_id = "meta-reference-mmlu"
        response = await eval_impl.run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=BenchmarkBenchmarkConfig(
            benchmark_config=dict(
                eval_candidate=ModelCandidate(
                    model=inference_model,
                    sampling_params=SamplingParams(),

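The three hunks above replace the removed `AppBenchmarkConfig` / `BenchmarkBenchmarkConfig` wrappers with a plain dict. A minimal sketch of the resulting call shape, with `benchmark_id`, `ModelCandidate`, and `SamplingParams` taken from the surrounding test and everything else illustrative:

```python
from llama_stack.apis.eval.eval import ModelCandidate
from llama_stack.apis.inference import SamplingParams


async def run_mmlu_eval(eval_impl, inference_model: str):
    # Sketch of the new call shape used in Testeval above: benchmark_config is
    # now a plain dict instead of the removed AppBenchmarkConfig /
    # BenchmarkBenchmarkConfig wrappers. eval_impl is whatever the eval_stack
    # fixture provides; this helper is illustrative, not part of the diff.
    return await eval_impl.run_eval(
        benchmark_id="meta-reference-mmlu",
        benchmark_config=dict(
            eval_candidate=ModelCandidate(
                model=inference_model,
                sampling_params=SamplingParams(),
            ),
        ),
    )
```
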
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -44,6 +44,15 @@
"metadata": null
}
},
"()_[('kwargs', {'session_id': '<UUID>', 'code': 'import pandas as pd\\n\\n# Load the CSV file\\ndf = pd.read_csv(\"<TEMP_FILE>\")\\n\\n# Print the first few rows of the dataframe\\nprint(df.head())\\n\\n# Print information about the dataframe\\nprint(df.info())\\n\\n# Print summary statistics about the dataframe\\nprint(df.describe())'}), ('tool_name', 'code_interpreter')]": {
"type": "value",
"value": {
"content": "error\n[stdout]\n[Errno 2] No such file or directory: 'bwrap'\n[/stdout]\n[stderr]\n[Errno 2] No such file or directory: 'bwrap'\n[/stderr]",
"error_code": null,
"error_message": null,
"metadata": null
}
},
"()_[('kwargs', {'session_id': '<UUID>', 'code': 'import pandas as pd\\n\\n# Load the CSV file\\ndf = pd.read_csv(\"<TEMP_FILE>\")\\n\\n# Print the first few rows of the dataframe\\nprint(df.head())\\n\\n# Print information about the dataframe\\nprint(df.info())\\n\\n# Print summary statistics of the dataframe\\nprint(df.describe())'}), ('tool_name', 'code_interpreter')]": {
"type": "value",
"value": {

@@ -71,15 +80,6 @@
"metadata": null
}
},
"()_[('kwargs', {'session_id': '<UUID>', 'code': 'import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load data\\ndf = pd.read_csv(\"<TEMP_FILE>\")\\n\\n# Convert \\'Year\\' column to datetime\\ndf[\\'Year\\'] = pd.to_datetime(df[\\'Year\\'])\\n\\n# Group by year and calculate average inflation\\naverage_inflation = df.groupby(\\'Year\\')[\\'Inflation\\'].mean().reset_index()\\n\\n# Plot average yearly inflation as a time series\\nplt.figure(figsize=(10,6))\\nplt.plot(average_inflation[\\'Year\\'], average_inflation[\\'Inflation\\'], marker=\\'o\\')\\nplt.title(\\'Average Yearly Inflation\\')\\nplt.xlabel(\\'Year\\')\\nplt.ylabel(\\'Inflation Rate\\')\\nplt.grid(True)\\nplt.show()'}), ('tool_name', 'code_interpreter')]": {
"type": "value",
"value": {
"content": "completed\n[stderr]\nTraceback (most recent call last):\n line 5, in <module>\n from bwrap.core import main\nModuleNotFoundError: No module named 'bwrap.core'\n[/stderr]",
"error_code": null,
"error_message": null,
"metadata": null
}
},
"()_[('kwargs', {'session_id': '<UUID>', 'code': 'import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load the CSV file\\ndf = pd.read_csv(\"<TEMP_FILE>\")\\n\\n# Convert the \\'Year\\' column to datetime\\ndf[\\'Year\\'] = pd.to_datetime(df[\\'Year\\'], format=\\'%Y\\')\\n\\n# Group by \\'Year\\' and calculate the average inflation\\ndf_avg_inflation = df.groupby(\\'Year\\')[\\'Inflation\\'].mean().reset_index()\\n\\n# Plot the average inflation as a time series\\nplt.figure(figsize=(10,6))\\nplt.plot(df_avg_inflation[\\'Year\\'], df_avg_inflation[\\'Inflation\\'], marker=\\'o\\')\\nplt.title(\\'Average Yearly Inflation\\')\\nplt.xlabel(\\'Year\\')\\nplt.ylabel(\\'Inflation\\')\\nplt.grid(True)\\nplt.show()'}), ('tool_name', 'code_interpreter')]": {
"type": "value",
"value": {

@@ -107,23 +107,23 @@
"type": "text"
},
{
"text": "Result 1:\nDocument_id:64211\nContent: .. _lora_finetune_label:\n\n============================\nFine-Tuning Llama2 with LoRA\n============================\n\nThis guide will teach you about `LoRA <https://arxiv.org/abs/2106.09685>`_, a parameter-efficient finetuning technique,\nand show you how you can use torchtune to finetune a Llama2 model with LoRA.\nIf you already know what LoRA is and want to get straight to running\nyour own LoRA finetune in torchtune, you can jump to :ref:`LoRA finetuning recipe in torchtune<lora_recipe_label>`.\n\n.. grid:: 2\n\n .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn\n\n * What LoRA is and how it saves memory during finetuning\n * An overview of LoRA components in torchtune\n * How to run a LoRA finetune using torchtune\n * How to experiment with different LoRA configurations\n\n .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites\n\n * Be familiar with :ref:`torchtune<overview_label>`\n * Make sure to :ref:`install torchtune<install_label>`\n * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`\n\nWhat is LoRA?\n-------------\n\n`LoRA <https://arxiv.org/abs/2106.09685>`_ is an adapter-based method for\nparameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network,\nthen freezes the network's remaining parameters. LoRA is most commonly applied to\ntransformer models, in which case it is common to add the low-rank matrices\nto some of the linear projections in each transformer layer's self-attention.\n\n.. note::\n\n If you're unfamiliar, check out these references for the `definition of rank <https://en.wikipedia.org/wiki/Rank_(linear_algebra)>`_\n and discussion of `low-rank approximations <https://en.wikipedia.org/wiki/Low-rank_approximation>`_.\n\nBy finetuning with LoRA (as opposed to finetuning all model parameters),\nyou can expect to see memory savings due to a substantial reduction in the\nnumber of parameters with gradients. When using an optimizer with momentum,\nlike `AdamW <https://py\n",
"text": "Result 1:\nDocument_id:cbc88\nContent: .. _lora_finetune_label:\n\n============================\nFine-Tuning Llama2 with LoRA\n============================\n\nThis guide will teach you about `LoRA <https://arxiv.org/abs/2106.09685>`_, a parameter-efficient finetuning technique,\nand show you how you can use torchtune to finetune a Llama2 model with LoRA.\nIf you already know what LoRA is and want to get straight to running\nyour own LoRA finetune in torchtune, you can jump to :ref:`LoRA finetuning recipe in torchtune<lora_recipe_label>`.\n\n.. grid:: 2\n\n .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn\n\n * What LoRA is and how it saves memory during finetuning\n * An overview of LoRA components in torchtune\n * How to run a LoRA finetune using torchtune\n * How to experiment with different LoRA configurations\n\n .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites\n\n * Be familiar with :ref:`torchtune<overview_label>`\n * Make sure to :ref:`install torchtune<install_label>`\n * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`\n\nWhat is LoRA?\n-------------\n\n`LoRA <https://arxiv.org/abs/2106.09685>`_ is an adapter-based method for\nparameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network,\nthen freezes the network's remaining parameters. LoRA is most commonly applied to\ntransformer models, in which case it is common to add the low-rank matrices\nto some of the linear projections in each transformer layer's self-attention.\n\n.. note::\n\n If you're unfamiliar, check out these references for the `definition of rank <https://en.wikipedia.org/wiki/Rank_(linear_algebra)>`_\n and discussion of `low-rank approximations <https://en.wikipedia.org/wiki/Low-rank_approximation>`_.\n\nBy finetuning with LoRA (as opposed to finetuning all model parameters),\nyou can expect to see memory savings due to a substantial reduction in the\nnumber of parameters with gradients. When using an optimizer with momentum,\nlike `AdamW <https://py\n",
"type": "text"
},
{
"text": "Result 2:\nDocument_id:64211\nContent: 06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_.\nMake sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.\nYou can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB):\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora\n\n.. note::\n Make sure to point to the location of your Llama2 weights and tokenizer. This can be done\n either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`\n or by directly modifying the :code:`7B_lora.yaml` file. See our \"\":ref:`config_tutorial_label`\" recipe\n for more details on how you can easily clone and modify torchtune configs.\n\n.. note::\n You can modify the value of :code:`nproc_per_node` depending on (a) the number of GPUs you have available,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the\n",
"text": "Result 2:\nDocument_id:cbc88\nContent: 06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_.\nMake sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.\nYou can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB):\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora\n\n.. note::\n Make sure to point to the location of your Llama2 weights and tokenizer. This can be done\n either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`\n or by directly modifying the :code:`7B_lora.yaml` file. See our \"\":ref:`config_tutorial_label`\" recipe\n for more details on how you can easily clone and modify torchtune configs.\n\n.. note::\n You can modify the value of :code:`nproc_per_node` depending on (a) the number of GPUs you have available,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the\n",
"type": "text"
},
{
"text": "Result 3:\nDocument_id:0c95c\nContent: with training with LoRA quickly,\njust specify any config with ``_lora`` in its name, e.g:\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device\n\n\nThere are two sets of parameters to customize LoRA to suit your needs. Firstly, the parameters which control\nwhich linear layers LoRA should be applied to in the model:\n\n* ``lora_attn_modules: List[str]`` accepts a list of strings specifying which layers of the model to apply\n LoRA to:\n\n * ``q_proj`` applies LoRA to the query projection layer.\n * ``k_proj`` applies LoRA to the key projection layer.\n * ``v_proj`` applies LoRA to the value projection layer.\n * ``output_proj`` applies LoRA to the attention output projection layer.\n\n Whilst adding more layers to be fine-tuned may improve model accuracy,\n this will come at the cost of increased memory usage and reduced training speed.\n\n* ``apply_lora_to_mlp: Bool`` applies LoRA to the MLP in each transformer layer.\n* ``apply_lora_to_output: Bool`` applies LoRA to the model's final output projection.\n This is usually a projection to vocabulary space (e.g. in language models), but\n other modelling tasks may have different projections - classifier models will project\n to the number of classes, for example\n\n.. note::\n\n Models which use tied embeddings (such as Gemma and Qwen2 1.5B and 0.5B) for the\n final output projection do not support ``apply_lora_to_output``.\n\nThese are all specified under the ``model`` flag or config entry, i.e:\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n model.apply_lora_to_mlp=True \\\n model.lora_attn_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"output_proj\"]\n\n.. code-block:: yaml\n\n model:\n _component_: torchtune.models.llama3.lora_llama3_8b\n apply_lora_to_mlp: True\n model.lora_attn_modules: [\"q_proj\", \"k_proj\", \"v_proj\",\"output_proj\"]\n\nSecondly, parameters which control the scale of the impact of LoRA on the model:\n\n* ``lora_rank: int`` affects the scale of\n",
"text": "Result 3:\nDocument_id:8892b\nContent: with training with LoRA quickly,\njust specify any config with ``_lora`` in its name, e.g:\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device\n\n\nThere are two sets of parameters to customize LoRA to suit your needs. Firstly, the parameters which control\nwhich linear layers LoRA should be applied to in the model:\n\n* ``lora_attn_modules: List[str]`` accepts a list of strings specifying which layers of the model to apply\n LoRA to:\n\n * ``q_proj`` applies LoRA to the query projection layer.\n * ``k_proj`` applies LoRA to the key projection layer.\n * ``v_proj`` applies LoRA to the value projection layer.\n * ``output_proj`` applies LoRA to the attention output projection layer.\n\n Whilst adding more layers to be fine-tuned may improve model accuracy,\n this will come at the cost of increased memory usage and reduced training speed.\n\n* ``apply_lora_to_mlp: Bool`` applies LoRA to the MLP in each transformer layer.\n* ``apply_lora_to_output: Bool`` applies LoRA to the model's final output projection.\n This is usually a projection to vocabulary space (e.g. in language models), but\n other modelling tasks may have different projections - classifier models will project\n to the number of classes, for example\n\n.. note::\n\n Models which use tied embeddings (such as Gemma and Qwen2 1.5B and 0.5B) for the\n final output projection do not support ``apply_lora_to_output``.\n\nThese are all specified under the ``model`` flag or config entry, i.e:\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n model.apply_lora_to_mlp=True \\\n model.lora_attn_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"output_proj\"]\n\n.. code-block:: yaml\n\n model:\n _component_: torchtune.models.llama3.lora_llama3_8b\n apply_lora_to_mlp: True\n model.lora_attn_modules: [\"q_proj\", \"k_proj\", \"v_proj\",\"output_proj\"]\n\nSecondly, parameters which control the scale of the impact of LoRA on the model:\n\n* ``lora_rank: int`` affects the scale of\n",
"type": "text"
},
{
"text": "Result 4:\nDocument_id:64211\nContent: LoRA to Llama2 models\n------------------------------\n\nWith torchtune, we can easily apply LoRA to Llama2 with a variety of different configurations.\nLet's take a look at how to construct Llama2 models in torchtune with and without LoRA.\n\n.. code-block:: python\n\n from torchtune.models.llama2 import llama2_7b, lora_llama2_7b\n\n # Build Llama2 without any LoRA layers\n base_model = llama2_7b()\n\n # The default settings for lora_llama2_7b will match those for llama2_7b\n # We just need to define which layers we want LoRA applied to.\n # Within each self-attention, we can choose from [\"q_proj\", \"k_proj\", \"v_proj\", and \"output_proj\"].\n # We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply LoRA to other linear\n # layers outside of the self-attention.\n lora_model = lora_llama2_7b(lora_attn_modules=[\"q_proj\", \"v_proj\"])\n\n.. note::\n\n Calling :func:`lora_llama_2_7b <torchtune.models.llama2.lora_llama2_7b>` alone will not handle the definition of which parameters are trainable.\n See :ref:`below<setting_trainable_params>` for how to do this.\n\nLet's inspect each of these models a bit more closely.\n\n.. code-block:: bash\n\n # Print the first layer's self-attention in the usual Llama2 model\n >>> print(base_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (output_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (pos_embeddings): RotaryPositionalEmbeddings()\n )\n\n # Print the same for Llama2 with LoRA weights\n >>> print(lora_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): LoRALinear(\n (dropout): Dropout(p=0.0, inplace=False)\n \n",
"text": "Result 4:\nDocument_id:cbc88\nContent: LoRA to Llama2 models\n------------------------------\n\nWith torchtune, we can easily apply LoRA to Llama2 with a variety of different configurations.\nLet's take a look at how to construct Llama2 models in torchtune with and without LoRA.\n\n.. code-block:: python\n\n from torchtune.models.llama2 import llama2_7b, lora_llama2_7b\n\n # Build Llama2 without any LoRA layers\n base_model = llama2_7b()\n\n # The default settings for lora_llama2_7b will match those for llama2_7b\n # We just need to define which layers we want LoRA applied to.\n # Within each self-attention, we can choose from [\"q_proj\", \"k_proj\", \"v_proj\", and \"output_proj\"].\n # We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply LoRA to other linear\n # layers outside of the self-attention.\n lora_model = lora_llama2_7b(lora_attn_modules=[\"q_proj\", \"v_proj\"])\n\n.. note::\n\n Calling :func:`lora_llama_2_7b <torchtune.models.llama2.lora_llama2_7b>` alone will not handle the definition of which parameters are trainable.\n See :ref:`below<setting_trainable_params>` for how to do this.\n\nLet's inspect each of these models a bit more closely.\n\n.. code-block:: bash\n\n # Print the first layer's self-attention in the usual Llama2 model\n >>> print(base_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (output_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (pos_embeddings): RotaryPositionalEmbeddings()\n )\n\n # Print the same for Llama2 with LoRA weights\n >>> print(lora_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): LoRALinear(\n (dropout): Dropout(p=0.0, inplace=False)\n \n",
"type": "text"
},
{
"text": "Result 5:\nDocument_id:1d70c\nContent: ora_finetune_label>`.\nFor more on QLoRA in torchtune, see our :ref:`QLoRA Tutorial <qlora_finetune_label>`.\n\nLet's take a look at how we can fine-tune Llama3-8B-Instruct with LoRA on a single device using torchtune. In this example, we will fine-tune\nfor one epoch on a common instruct dataset for illustrative purposes. The basic command for a single-device LoRA fine-tune is\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device\n\n.. note::\n To see a full list of recipes and their corresponding configs, simply run ``tune ls`` from the command line.\n\nWe can also add :ref:`command-line overrides <cli_override>` as needed, e.g.\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n checkpointer.checkpoint_dir=<checkpoint_dir> \\\n tokenizer.path=<checkpoint_dir>/tokenizer.model \\\n checkpointer.output_dir=<checkpoint_dir>\n\nThis will load the Llama3-8B-Instruct checkpoint and tokenizer from ``<checkpoint_dir>`` used in the :ref:`tune download <tune_download_label>` command above,\nthen save a final checkpoint in the same directory following the original format. For more details on the\ncheckpoint formats supported in torchtune, see our :ref:`checkpointing deep-dive <understand_checkpointer>`.\n\n.. note::\n To see the full set of configurable parameters for this (and other) configs we can use :ref:`tune cp <tune_cp_cli_label>` to copy (and modify)\n the default config. :ref:`tune cp <tune_cp_cli_label>` can be used with recipe scripts too, in case you want to make more custom changes\n that cannot be achieved by directly modifying existing configurable parameters. For more on :ref:`tune cp <tune_cp_cli_label>` see the section on\n :ref:`modifying configs <tune_cp_label>` in our \":ref:`finetune_llama_label`\" tutorial.\n\nOnce training is complete, the model checkpoints will be saved and their locations will be logged. For\nLoRA fine-tuning, the final checkpoint will contain the merged weights, and a copy of just the (much smaller) LoRA weights\nwill\n",
"text": "Result 5:\nDocument_id:9dcb7\nContent: ora_finetune_label>`.\nFor more on QLoRA in torchtune, see our :ref:`QLoRA Tutorial <qlora_finetune_label>`.\n\nLet's take a look at how we can fine-tune Llama3-8B-Instruct with LoRA on a single device using torchtune. In this example, we will fine-tune\nfor one epoch on a common instruct dataset for illustrative purposes. The basic command for a single-device LoRA fine-tune is\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device\n\n.. note::\n To see a full list of recipes and their corresponding configs, simply run ``tune ls`` from the command line.\n\nWe can also add :ref:`command-line overrides <cli_override>` as needed, e.g.\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n checkpointer.checkpoint_dir=<checkpoint_dir> \\\n tokenizer.path=<checkpoint_dir>/tokenizer.model \\\n checkpointer.output_dir=<checkpoint_dir>\n\nThis will load the Llama3-8B-Instruct checkpoint and tokenizer from ``<checkpoint_dir>`` used in the :ref:`tune download <tune_download_label>` command above,\nthen save a final checkpoint in the same directory following the original format. For more details on the\ncheckpoint formats supported in torchtune, see our :ref:`checkpointing deep-dive <understand_checkpointer>`.\n\n.. note::\n To see the full set of configurable parameters for this (and other) configs we can use :ref:`tune cp <tune_cp_cli_label>` to copy (and modify)\n the default config. :ref:`tune cp <tune_cp_cli_label>` can be used with recipe scripts too, in case you want to make more custom changes\n that cannot be achieved by directly modifying existing configurable parameters. For more on :ref:`tune cp <tune_cp_cli_label>` see the section on\n :ref:`modifying configs <tune_cp_label>` in our \":ref:`finetune_llama_label`\" tutorial.\n\nOnce training is complete, the model checkpoints will be saved and their locations will be logged. For\nLoRA fine-tuning, the final checkpoint will contain the merged weights, and a copy of just the (much smaller) LoRA weights\nwill\n",
"type": "text"
},
{

@@ -135,11 +135,11 @@
"error_message": null,
"metadata": {
"document_ids": [
"6421150d-d334-4163-a058-3818b2b742e9",
"6421150d-d334-4163-a058-3818b2b742e9",
"0c95cff3-5612-40cf-a73d-77644a2462d0",
"6421150d-d334-4163-a058-3818b2b742e9",
"1d70c86d-4cdf-4be9-a1f2-8a271b15ce2c"
"cbc884b1-9d88-4d5c-aff4-7a4b3a56618c",
"cbc884b1-9d88-4d5c-aff4-7a4b3a56618c",
"8892b092-6394-471e-b143-a23c6cc374f8",
"cbc884b1-9d88-4d5c-aff4-7a4b3a56618c",
"9dcb747d-0627-40cc-a23c-0bee2b6b05af"
]
}
}

@@ -307,23 +307,23 @@
"type": "text"
},
{
"text": "Result 1:\nDocument_id:7bdfa\nContent: conversational data, :func:`~torchtune.datasets.chat_dataset` seems to be a good fit. For any\ncustom local dataset we always need to specify ``source``, ``data_files``, and ``split`` for any dataset\nbuilder in torchtune. For :func:`~torchtune.datasets.chat_dataset`, we additionally need to specify\n``conversation_column`` and ``conversation_style``. Our data follows the ``\"sharegpt\"`` format, so\nwe can specify that here. Altogether, our :func:`~torchtune.datasets.chat_dataset` call should\nlook like so:\n\n.. code-block:: python\n\n from torchtune.datasets import chat_dataset\n from torchtune.models.llama3 import llama3_tokenizer\n\n tokenizer = llama3_tokenizer(\"/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model\")\n ds = chat_dataset(\n tokenizer=tokenizer,\n source=\"json\",\n data_files=\"data/my_data.json\",\n split=\"train\",\n conversation_column=\"dialogue\",\n conversation_style=\"sharegpt\",\n )\n\n.. code-block:: yaml\n\n # In config\n tokenizer:\n _component_: torchtune.models.llama3.llama3_tokenizer\n path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model\n\n dataset:\n _component_: torchtune.datasets.chat_dataset\n source: json\n data_files: data/my_data.json\n split: train\n conversation_column: dialogue\n conversation_style: sharegpt\n\n.. note::\n You can pass in any keyword argument for `load_dataset <https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset>`_ into all our\n Dataset classes and they will honor them. This is useful for common parameters\n such as specifying the data split with :code:`split` or configuration with\n :code:`name`\n\nIf you needed to add a prompt template, you would simply pass it into the tokenizer.\nSince we're fine-tuning Llama3, the tokenizer will handle all formatting for\nus and prompt templates are optional. Other models such as Mistral's :class:`~torchtune.models.mistral._tokenizer.MistralTokenizer`,\nuse a chat template by default (:class:`~torchtune.models.mistral.MistralChatTemplate`) to format\nall messages according to their `recommendations <https://\n",
"text": "Result 1:\nDocument_id:f76dc\nContent: conversational data, :func:`~torchtune.datasets.chat_dataset` seems to be a good fit. For any\ncustom local dataset we always need to specify ``source``, ``data_files``, and ``split`` for any dataset\nbuilder in torchtune. For :func:`~torchtune.datasets.chat_dataset`, we additionally need to specify\n``conversation_column`` and ``conversation_style``. Our data follows the ``\"sharegpt\"`` format, so\nwe can specify that here. Altogether, our :func:`~torchtune.datasets.chat_dataset` call should\nlook like so:\n\n.. code-block:: python\n\n from torchtune.datasets import chat_dataset\n from torchtune.models.llama3 import llama3_tokenizer\n\n tokenizer = llama3_tokenizer(\"/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model\")\n ds = chat_dataset(\n tokenizer=tokenizer,\n source=\"json\",\n data_files=\"data/my_data.json\",\n split=\"train\",\n conversation_column=\"dialogue\",\n conversation_style=\"sharegpt\",\n )\n\n.. code-block:: yaml\n\n # In config\n tokenizer:\n _component_: torchtune.models.llama3.llama3_tokenizer\n path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model\n\n dataset:\n _component_: torchtune.datasets.chat_dataset\n source: json\n data_files: data/my_data.json\n split: train\n conversation_column: dialogue\n conversation_style: sharegpt\n\n.. note::\n You can pass in any keyword argument for `load_dataset <https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset>`_ into all our\n Dataset classes and they will honor them. This is useful for common parameters\n such as specifying the data split with :code:`split` or configuration with\n :code:`name`\n\nIf you needed to add a prompt template, you would simply pass it into the tokenizer.\nSince we're fine-tuning Llama3, the tokenizer will handle all formatting for\nus and prompt templates are optional. Other models such as Mistral's :class:`~torchtune.models.mistral._tokenizer.MistralTokenizer`,\nuse a chat template by default (:class:`~torchtune.models.mistral.MistralChatTemplate`) to format\nall messages according to their `recommendations <https://\n",
"type": "text"
},
{
"text": "Result 2:\nDocument_id:64211\nContent: .. _lora_finetune_label:\n\n============================\nFine-Tuning Llama2 with LoRA\n============================\n\nThis guide will teach you about `LoRA <https://arxiv.org/abs/2106.09685>`_, a parameter-efficient finetuning technique,\nand show you how you can use torchtune to finetune a Llama2 model with LoRA.\nIf you already know what LoRA is and want to get straight to running\nyour own LoRA finetune in torchtune, you can jump to :ref:`LoRA finetuning recipe in torchtune<lora_recipe_label>`.\n\n.. grid:: 2\n\n .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn\n\n * What LoRA is and how it saves memory during finetuning\n * An overview of LoRA components in torchtune\n * How to run a LoRA finetune using torchtune\n * How to experiment with different LoRA configurations\n\n .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites\n\n * Be familiar with :ref:`torchtune<overview_label>`\n * Make sure to :ref:`install torchtune<install_label>`\n * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`\n\nWhat is LoRA?\n-------------\n\n`LoRA <https://arxiv.org/abs/2106.09685>`_ is an adapter-based method for\nparameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network,\nthen freezes the network's remaining parameters. LoRA is most commonly applied to\ntransformer models, in which case it is common to add the low-rank matrices\nto some of the linear projections in each transformer layer's self-attention.\n\n.. note::\n\n If you're unfamiliar, check out these references for the `definition of rank <https://en.wikipedia.org/wiki/Rank_(linear_algebra)>`_\n and discussion of `low-rank approximations <https://en.wikipedia.org/wiki/Low-rank_approximation>`_.\n\nBy finetuning with LoRA (as opposed to finetuning all model parameters),\nyou can expect to see memory savings due to a substantial reduction in the\nnumber of parameters with gradients. When using an optimizer with momentum,\nlike `AdamW <https://py\n",
"text": "Result 2:\nDocument_id:c4fc3\nContent: .. _lora_finetune_label:\n\n============================\nFine-Tuning Llama2 with LoRA\n============================\n\nThis guide will teach you about `LoRA <https://arxiv.org/abs/2106.09685>`_, a parameter-efficient finetuning technique,\nand show you how you can use torchtune to finetune a Llama2 model with LoRA.\nIf you already know what LoRA is and want to get straight to running\nyour own LoRA finetune in torchtune, you can jump to :ref:`LoRA finetuning recipe in torchtune<lora_recipe_label>`.\n\n.. grid:: 2\n\n .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn\n\n * What LoRA is and how it saves memory during finetuning\n * An overview of LoRA components in torchtune\n * How to run a LoRA finetune using torchtune\n * How to experiment with different LoRA configurations\n\n .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites\n\n * Be familiar with :ref:`torchtune<overview_label>`\n * Make sure to :ref:`install torchtune<install_label>`\n * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`\n\nWhat is LoRA?\n-------------\n\n`LoRA <https://arxiv.org/abs/2106.09685>`_ is an adapter-based method for\nparameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network,\nthen freezes the network's remaining parameters. LoRA is most commonly applied to\ntransformer models, in which case it is common to add the low-rank matrices\nto some of the linear projections in each transformer layer's self-attention.\n\n.. note::\n\n If you're unfamiliar, check out these references for the `definition of rank <https://en.wikipedia.org/wiki/Rank_(linear_algebra)>`_\n and discussion of `low-rank approximations <https://en.wikipedia.org/wiki/Low-rank_approximation>`_.\n\nBy finetuning with LoRA (as opposed to finetuning all model parameters),\nyou can expect to see memory savings due to a substantial reduction in the\nnumber of parameters with gradients. When using an optimizer with momentum,\nlike `AdamW <https://py\n",
"type": "text"
},
{
"text": "Result 3:\nDocument_id:0c95c\nContent: ` module, which we swap\n out for :class:`~torchtune.modules.peft.LoRALinear` when ``use_dora=True``.\n\n.. _glossary_distrib:\n\n\n.. TODO\n\n.. Distributed\n.. -----------\n\n.. .. _glossary_fsdp:\n\n.. Fully Sharded Data Parallel (FSDP)\n.. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. All our ``_distributed`` recipes use `FSDP <https://pytorch.org/docs/stable/fsdp.html>`.\n.. .. _glossary_fsdp2:\n\n",
"text": "Result 3:\nDocument_id:de2d4\nContent: ` module, which we swap\n out for :class:`~torchtune.modules.peft.LoRALinear` when ``use_dora=True``.\n\n.. _glossary_distrib:\n\n\n.. TODO\n\n.. Distributed\n.. -----------\n\n.. .. _glossary_fsdp:\n\n.. Fully Sharded Data Parallel (FSDP)\n.. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. All our ``_distributed`` recipes use `FSDP <https://pytorch.org/docs/stable/fsdp.html>`.\n.. .. _glossary_fsdp2:\n\n",
"type": "text"
},
{
"text": "Result 4:\nDocument_id:64211\nContent: 06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_.\nMake sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.\nYou can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB):\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora\n\n.. note::\n Make sure to point to the location of your Llama2 weights and tokenizer. This can be done\n either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`\n or by directly modifying the :code:`7B_lora.yaml` file. See our \"\":ref:`config_tutorial_label`\" recipe\n for more details on how you can easily clone and modify torchtune configs.\n\n.. note::\n You can modify the value of :code:`nproc_per_node` depending on (a) the number of GPUs you have available,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the\n",
"text": "Result 4:\nDocument_id:c4fc3\nContent: 06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_.\nMake sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.\nYou can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB):\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora\n\n.. note::\n Make sure to point to the location of your Llama2 weights and tokenizer. This can be done\n either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`\n or by directly modifying the :code:`7B_lora.yaml` file. See our \"\":ref:`config_tutorial_label`\" recipe\n for more details on how you can easily clone and modify torchtune configs.\n\n.. note::\n You can modify the value of :code:`nproc_per_node` depending on (a) the number of GPUs you have available,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the\n",
"type": "text"
},
{
"text": "Result 5:\nDocument_id:0c95c\nContent: etune\n:func:`torchtune.models.llama3.llama3_8b` with DoRA, you would use :func:`torchtune.models.llama3.lora_llama3_8b` with ``use_dora=True``:\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n model.use_dora=True\n\n.. code-block:: yaml\n\n model:\n _component_: torchtune.models.lora_llama3_8b\n use_dora: True\n\nSince DoRA extends LoRA, the parameters for :ref:`customizing LoRA <glossary_lora>` are identical. You can also quantize the base model weights like in :ref:`glossary_qlora` by using ``quantize=True`` to reap\neven more memory savings!\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n model.apply_lora_to_mlp=True \\\n model.lora_attn_modules=[\"q_proj\",\"k_proj\",\"v_proj\"] \\\n model.lora_rank=16 \\\n model.lora_alpha=32 \\\n model.use_dora=True \\\n model.quantize_base=True\n\n.. code-block:: yaml\n\n model:\n _component_: torchtune.models.lora_llama3_8b\n apply_lora_to_mlp: True\n lora_attn_modules: [\"q_proj\", \"k_proj\", \"v_proj\"]\n lora_rank: 16\n lora_alpha: 32\n use_dora: True\n quantize_base: True\n\n\n.. note::\n\n Under the hood, we've enabled DoRA by adding the :class:`~torchtune.modules.peft.DoRALinear` module, which we swap\n out for :class:`~torchtune.modules.peft.LoRALinear` when ``use_dora=True``.\n\n.. _glossary_distrib:\n\n\n.. TODO\n\n.. Distributed\n.. -----------\n\n.. .. _glossary_fsdp:\n\n.. Fully Sharded Data Parallel (FSDP)\n.. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. All our ``_distributed`` recipes use `FSDP <https://pytorch.org/docs/stable/fsdp.html>`.\n.. .. _glossary_fsdp2:\n\n",
"text": "Result 5:\nDocument_id:de2d4\nContent: etune\n:func:`torchtune.models.llama3.llama3_8b` with DoRA, you would use :func:`torchtune.models.llama3.lora_llama3_8b` with ``use_dora=True``:\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n model.use_dora=True\n\n.. code-block:: yaml\n\n model:\n _component_: torchtune.models.lora_llama3_8b\n use_dora: True\n\nSince DoRA extends LoRA, the parameters for :ref:`customizing LoRA <glossary_lora>` are identical. You can also quantize the base model weights like in :ref:`glossary_qlora` by using ``quantize=True`` to reap\neven more memory savings!\n\n.. code-block:: bash\n\n tune run lora_finetune_single_device --config llama3/8B_lora_single_device \\\n model.apply_lora_to_mlp=True \\\n model.lora_attn_modules=[\"q_proj\",\"k_proj\",\"v_proj\"] \\\n model.lora_rank=16 \\\n model.lora_alpha=32 \\\n model.use_dora=True \\\n model.quantize_base=True\n\n.. code-block:: yaml\n\n model:\n _component_: torchtune.models.lora_llama3_8b\n apply_lora_to_mlp: True\n lora_attn_modules: [\"q_proj\", \"k_proj\", \"v_proj\"]\n lora_rank: 16\n lora_alpha: 32\n use_dora: True\n quantize_base: True\n\n\n.. note::\n\n Under the hood, we've enabled DoRA by adding the :class:`~torchtune.modules.peft.DoRALinear` module, which we swap\n out for :class:`~torchtune.modules.peft.LoRALinear` when ``use_dora=True``.\n\n.. _glossary_distrib:\n\n\n.. TODO\n\n.. Distributed\n.. -----------\n\n.. .. _glossary_fsdp:\n\n.. Fully Sharded Data Parallel (FSDP)\n.. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. All our ``_distributed`` recipes use `FSDP <https://pytorch.org/docs/stable/fsdp.html>`.\n.. .. _glossary_fsdp2:\n\n",
"type": "text"
},
{

@@ -335,11 +335,11 @@
"error_message": null,
"metadata": {
"document_ids": [
"7bdfad34-d546-4e98-9757-a0289696cd97",
"6421150d-d334-4163-a058-3818b2b742e9",
"0c95cff3-5612-40cf-a73d-77644a2462d0",
"6421150d-d334-4163-a058-3818b2b742e9",
"0c95cff3-5612-40cf-a73d-77644a2462d0"
"f76dc7f5-9648-4272-a579-c8387fb1408a",
"c4fc3cb6-6172-489e-90a7-b39d343e14c0",
"de2d49de-55de-44dd-9bca-6f4f6d633b0a",
"c4fc3cb6-6172-489e-90a7-b39d343e14c0",
"de2d49de-55de-44dd-9bca-6f4f6d633b0a"
]
}
}

@@ -362,23 +362,23 @@
"type": "text"
},
{
"text": "Result 1:\nDocument_id:7da0c\nContent: .. _lora_finetune_label:\n\n============================\nFine-Tuning Llama2 with LoRA\n============================\n\nThis guide will teach you about `LoRA <https://arxiv.org/abs/2106.09685>`_, a parameter-efficient finetuning technique,\nand show you how you can use torchtune to finetune a Llama2 model with LoRA.\nIf you already know what LoRA is and want to get straight to running\nyour own LoRA finetune in torchtune, you can jump to :ref:`LoRA finetuning recipe in torchtune<lora_recipe_label>`.\n\n.. grid:: 2\n\n .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn\n\n * What LoRA is and how it saves memory during finetuning\n * An overview of LoRA components in torchtune\n * How to run a LoRA finetune using torchtune\n * How to experiment with different LoRA configurations\n\n .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites\n\n * Be familiar with :ref:`torchtune<overview_label>`\n * Make sure to :ref:`install torchtune<install_label>`\n * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`\n\nWhat is LoRA?\n-------------\n\n`LoRA <https://arxiv.org/abs/2106.09685>`_ is an adapter-based method for\nparameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network,\nthen freezes the network's remaining parameters. LoRA is most commonly applied to\ntransformer models, in which case it is common to add the low-rank matrices\nto some of the linear projections in each transformer layer's self-attention.\n\n.. note::\n\n If you're unfamiliar, check out these references for the `definition of rank <https://en.wikipedia.org/wiki/Rank_(linear_algebra)>`_\n and discussion of `low-rank approximations <https://en.wikipedia.org/wiki/Low-rank_approximation>`_.\n\nBy finetuning with LoRA (as opposed to finetuning all model parameters),\nyou can expect to see memory savings due to a substantial reduction in the\nnumber of parameters with gradients. When using an optimizer with momentum,\nlike `AdamW <https://py\n",
|
||||
"text": "Result 1:\nDocument_id:c4fc3\nContent: .. _lora_finetune_label:\n\n============================\nFine-Tuning Llama2 with LoRA\n============================\n\nThis guide will teach you about `LoRA <https://arxiv.org/abs/2106.09685>`_, a parameter-efficient finetuning technique,\nand show you how you can use torchtune to finetune a Llama2 model with LoRA.\nIf you already know what LoRA is and want to get straight to running\nyour own LoRA finetune in torchtune, you can jump to :ref:`LoRA finetuning recipe in torchtune<lora_recipe_label>`.\n\n.. grid:: 2\n\n .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn\n\n * What LoRA is and how it saves memory during finetuning\n * An overview of LoRA components in torchtune\n * How to run a LoRA finetune using torchtune\n * How to experiment with different LoRA configurations\n\n .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites\n\n * Be familiar with :ref:`torchtune<overview_label>`\n * Make sure to :ref:`install torchtune<install_label>`\n * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`\n\nWhat is LoRA?\n-------------\n\n`LoRA <https://arxiv.org/abs/2106.09685>`_ is an adapter-based method for\nparameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network,\nthen freezes the network's remaining parameters. LoRA is most commonly applied to\ntransformer models, in which case it is common to add the low-rank matrices\nto some of the linear projections in each transformer layer's self-attention.\n\n.. note::\n\n If you're unfamiliar, check out these references for the `definition of rank <https://en.wikipedia.org/wiki/Rank_(linear_algebra)>`_\n and discussion of `low-rank approximations <https://en.wikipedia.org/wiki/Low-rank_approximation>`_.\n\nBy finetuning with LoRA (as opposed to finetuning all model parameters),\nyou can expect to see memory savings due to a substantial reduction in the\nnumber of parameters with gradients. When using an optimizer with momentum,\nlike `AdamW <https://py\n",
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"text": "Result 2:\nDocument_id:7da0c\nContent: LoRA to Llama2 models\n------------------------------\n\nWith torchtune, we can easily apply LoRA to Llama2 with a variety of different configurations.\nLet's take a look at how to construct Llama2 models in torchtune with and without LoRA.\n\n.. code-block:: python\n\n from torchtune.models.llama2 import llama2_7b, lora_llama2_7b\n\n # Build Llama2 without any LoRA layers\n base_model = llama2_7b()\n\n # The default settings for lora_llama2_7b will match those for llama2_7b\n # We just need to define which layers we want LoRA applied to.\n # Within each self-attention, we can choose from [\"q_proj\", \"k_proj\", \"v_proj\", and \"output_proj\"].\n # We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply LoRA to other linear\n # layers outside of the self-attention.\n lora_model = lora_llama2_7b(lora_attn_modules=[\"q_proj\", \"v_proj\"])\n\n.. note::\n\n Calling :func:`lora_llama_2_7b <torchtune.models.llama2.lora_llama2_7b>` alone will not handle the definition of which parameters are trainable.\n See :ref:`below<setting_trainable_params>` for how to do this.\n\nLet's inspect each of these models a bit more closely.\n\n.. code-block:: bash\n\n # Print the first layer's self-attention in the usual Llama2 model\n >>> print(base_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (output_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (pos_embeddings): RotaryPositionalEmbeddings()\n )\n\n # Print the same for Llama2 with LoRA weights\n >>> print(lora_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): LoRALinear(\n (dropout): Dropout(p=0.0, inplace=False)\n \n",
|
||||
"text": "Result 2:\nDocument_id:c4fc3\nContent: LoRA to Llama2 models\n------------------------------\n\nWith torchtune, we can easily apply LoRA to Llama2 with a variety of different configurations.\nLet's take a look at how to construct Llama2 models in torchtune with and without LoRA.\n\n.. code-block:: python\n\n from torchtune.models.llama2 import llama2_7b, lora_llama2_7b\n\n # Build Llama2 without any LoRA layers\n base_model = llama2_7b()\n\n # The default settings for lora_llama2_7b will match those for llama2_7b\n # We just need to define which layers we want LoRA applied to.\n # Within each self-attention, we can choose from [\"q_proj\", \"k_proj\", \"v_proj\", and \"output_proj\"].\n # We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply LoRA to other linear\n # layers outside of the self-attention.\n lora_model = lora_llama2_7b(lora_attn_modules=[\"q_proj\", \"v_proj\"])\n\n.. note::\n\n Calling :func:`lora_llama_2_7b <torchtune.models.llama2.lora_llama2_7b>` alone will not handle the definition of which parameters are trainable.\n See :ref:`below<setting_trainable_params>` for how to do this.\n\nLet's inspect each of these models a bit more closely.\n\n.. code-block:: bash\n\n # Print the first layer's self-attention in the usual Llama2 model\n >>> print(base_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (output_proj): Linear(in_features=4096, out_features=4096, bias=False)\n (pos_embeddings): RotaryPositionalEmbeddings()\n )\n\n # Print the same for Llama2 with LoRA weights\n >>> print(lora_model.layers[0].attn)\n MultiHeadAttention(\n (q_proj): LoRALinear(\n (dropout): Dropout(p=0.0, inplace=False)\n \n",
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"text": "Result 3:\nDocument_id:7da0c\nContent: 06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_.\nMake sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.\nYou can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB):\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora\n\n.. note::\n Make sure to point to the location of your Llama2 weights and tokenizer. This can be done\n either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`\n or by directly modifying the :code:`7B_lora.yaml` file. See our \"\":ref:`config_tutorial_label`\" recipe\n for more details on how you can easily clone and modify torchtune configs.\n\n.. note::\n You can modify the value of :code:`nproc_per_node` depending on (a) the number of GPUs you have available,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the\n",
|
||||
"text": "Result 3:\nDocument_id:c4fc3\nContent: 06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_.\nMake sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.\nYou can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB):\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora\n\n.. note::\n Make sure to point to the location of your Llama2 weights and tokenizer. This can be done\n either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`\n or by directly modifying the :code:`7B_lora.yaml` file. See our \"\":ref:`config_tutorial_label`\" recipe\n for more details on how you can easily clone and modify torchtune configs.\n\n.. note::\n You can modify the value of :code:`nproc_per_node` depending on (a) the number of GPUs you have available,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the\n",
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"text": "Result 4:\nDocument_id:7da0c\nContent: from our Llama2\nmodel without any wrappers or custom checkpoint conversion logic.\n\n.. code-block:: python\n\n # Assuming that base_model already has the pretrained Llama2 weights,\n # this will directly load them into your LoRA model without any conversion necessary.\n lora_model.load_state_dict(base_model.state_dict(), strict=False)\n\n.. note::\n Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in\n the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via\n :func:`validate_missing_and_unexpected_for_lora() <torchtune.modules.peft.validate_missing_and_unexpected_for_lora>`.\n\nOnce we've loaded the base model weights, we also want to set only LoRA parameters to trainable.\n\n.. _setting_trainable_params:\n\n.. code-block:: python\n\n from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n\n # Fetch all params from the model that are associated with LoRA.\n lora_params = get_adapter_params(lora_model)\n\n # Set requires_grad=True on lora_params, and requires_grad=False on all others.\n set_trainable_params(lora_model, lora_params)\n\n # Print the total number of parameters\n total_params = sum([p.numel() for p in lora_model.parameters()])\n trainable_params = sum([p.numel() for p in lora_model.parameters() if p.requires_grad])\n print(\n f\"\"\"\n {total_params} total params,\n {trainable_params}\" trainable params,\n {(100.0 * trainable_params / total_params):.2f}% of all params are trainable.\n \"\"\"\n )\n\n 6742609920 total params,\n 4194304 trainable params,\n 0.06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92\n",
|
||||
"text": "Result 4:\nDocument_id:c4fc3\nContent: from our Llama2\nmodel without any wrappers or custom checkpoint conversion logic.\n\n.. code-block:: python\n\n # Assuming that base_model already has the pretrained Llama2 weights,\n # this will directly load them into your LoRA model without any conversion necessary.\n lora_model.load_state_dict(base_model.state_dict(), strict=False)\n\n.. note::\n Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in\n the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via\n :func:`validate_missing_and_unexpected_for_lora() <torchtune.modules.peft.validate_missing_and_unexpected_for_lora>`.\n\nOnce we've loaded the base model weights, we also want to set only LoRA parameters to trainable.\n\n.. _setting_trainable_params:\n\n.. code-block:: python\n\n from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n\n # Fetch all params from the model that are associated with LoRA.\n lora_params = get_adapter_params(lora_model)\n\n # Set requires_grad=True on lora_params, and requires_grad=False on all others.\n set_trainable_params(lora_model, lora_params)\n\n # Print the total number of parameters\n total_params = sum([p.numel() for p in lora_model.parameters()])\n trainable_params = sum([p.numel() for p in lora_model.parameters() if p.requires_grad])\n print(\n f\"\"\"\n {total_params} total params,\n {trainable_params}\" trainable params,\n {(100.0 * trainable_params / total_params):.2f}% of all params are trainable.\n \"\"\"\n )\n\n 6742609920 total params,\n 4194304 trainable params,\n 0.06% of all params are trainable.\n\n.. note::\n If you are directly using the LoRA recipe (as detailed :ref:`here<lora_recipe_label>`), you need only pass the\n relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care\n of in the recipe.\n\n\n.. _lora_recipe_label:\n\nLoRA finetuning recipe in torchtune\n-----------------------------------\n\nFinally, we can put it all together and finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92\n",
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"text": "Result 5:\nDocument_id:7da0c\nContent: ,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the default is to apply LoRA to Q and V projections with a rank of 8.\nSome experiments with LoRA have found that it can be beneficial to apply LoRA to all linear layers in\nthe self-attention, and to increase the rank to 16 or 32. Note that this is likely to increase our max memory,\nbut as long as we keep :code:`rank<<embed_dim`, the impact should be relatively minor.\n\nLet's run this experiment. We can also increase alpha (in general it is good practice to scale alpha and rank together).\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \\\n lora_attn_modules=['q_proj','k_proj','v_proj','output_proj'] \\\n lora_rank=32 lora_alpha=64 output_dir=./lora_experiment_1\n\nA comparison of the (smoothed) loss curves between this run and our baseline over the first 500 steps can be seen below.\n\n.. image:: /_static/img/lora_experiment_loss_curves.png\n\n.. note::\n The above figure was generated with W&B. You can use torchtune's :class:`~torchtune.training.metric_logging.WandBLogger`\n to generate similar loss curves, but you will need to install W&B and setup an account separately. For more details on\n using W&B in torchtune, see our \":ref:`wandb_logging`\" recipe.\n\n.. _lora_tutorial_memory_tradeoff_label:\n\nTrading off memory and model performance with LoRA\n--------------------------------------------------\n\nIn the preceding example, we ran LoRA on two devices. But given LoRA's low memory footprint, we can run fine-tuning\non a single device using most commodity GPUs which support `bfloat16 <https://\n",
|
||||
"text": "Result 5:\nDocument_id:c4fc3\nContent: ,\n and (b) the memory constraints of your hardware.\n\nThe preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit.\nLet's take a closer look at some of the :code:`lora_finetune_distributed` config.\n\n.. code-block:: yaml\n\n # Model Arguments\n model:\n _component_: lora_llama2_7b\n lora_attn_modules: ['q_proj', 'v_proj']\n lora_rank: 8\n lora_alpha: 16\n ...\n\nWe see that the default is to apply LoRA to Q and V projections with a rank of 8.\nSome experiments with LoRA have found that it can be beneficial to apply LoRA to all linear layers in\nthe self-attention, and to increase the rank to 16 or 32. Note that this is likely to increase our max memory,\nbut as long as we keep :code:`rank<<embed_dim`, the impact should be relatively minor.\n\nLet's run this experiment. We can also increase alpha (in general it is good practice to scale alpha and rank together).\n\n.. code-block:: bash\n\n tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \\\n lora_attn_modules=['q_proj','k_proj','v_proj','output_proj'] \\\n lora_rank=32 lora_alpha=64 output_dir=./lora_experiment_1\n\nA comparison of the (smoothed) loss curves between this run and our baseline over the first 500 steps can be seen below.\n\n.. image:: /_static/img/lora_experiment_loss_curves.png\n\n.. note::\n The above figure was generated with W&B. You can use torchtune's :class:`~torchtune.training.metric_logging.WandBLogger`\n to generate similar loss curves, but you will need to install W&B and setup an account separately. For more details on\n using W&B in torchtune, see our \":ref:`wandb_logging`\" recipe.\n\n.. _lora_tutorial_memory_tradeoff_label:\n\nTrading off memory and model performance with LoRA\n--------------------------------------------------\n\nIn the preceding example, we ran LoRA on two devices. But given LoRA's low memory footprint, we can run fine-tuning\non a single device using most commodity GPUs which support `bfloat16 <https://\n",
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
|
|
@ -390,11 +390,11 @@
|
|||
"error_message": null,
|
||||
"metadata": {
|
||||
"document_ids": [
|
||||
"7da0c755-7ffa-4c1a-9ab0-cfdda7cce00f",
|
||||
"7da0c755-7ffa-4c1a-9ab0-cfdda7cce00f",
|
||||
"7da0c755-7ffa-4c1a-9ab0-cfdda7cce00f",
|
||||
"7da0c755-7ffa-4c1a-9ab0-cfdda7cce00f",
|
||||
"7da0c755-7ffa-4c1a-9ab0-cfdda7cce00f"
|
||||
"c4fc3cb6-6172-489e-90a7-b39d343e14c0",
|
||||
"c4fc3cb6-6172-489e-90a7-b39d343e14c0",
|
||||
"c4fc3cb6-6172-489e-90a7-b39d343e14c0",
|
||||
"c4fc3cb6-6172-489e-90a7-b39d343e14c0",
|
||||
"c4fc3cb6-6172-489e-90a7-b39d343e14c0"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -26,6 +26,7 @@ from llama_stack.apis.post_training import (
|
|||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
|
||||
class TestPostTraining:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervised_fine_tune(self, post_training_stack):
|
||||
|
|
@ -16,6 +16,7 @@ import pytest
|
|||
from pytest import CollectReport
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.env import get_env_or_fail
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.models.llama.sku_list import (
|
||||
all_registered_models,
|
||||
|
|
@ -26,7 +27,6 @@ from llama_stack.models.llama.sku_list import (
|
|||
safety_models,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
|
||||
from .metadata import API_MAPS
|
||||
|
||||
|
|
|
|||
160
tests/integration/scoring/test_scoring.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from ..datasetio.test_datasetio import register_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_judge_prompt_template():
|
||||
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
|
||||
|
||||
|
||||
def test_scoring_functions_list(llama_stack_client):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
response = llama_stack_client.scoring_functions.list()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) > 0
|
||||
|
||||
|
||||
def test_scoring_score(llama_stack_client):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
||||
scoring_functions = {
|
||||
scoring_fns_list[0].identifier: None,
|
||||
}
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
|
||||
# score batch
|
||||
response = llama_stack_client.scoring.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
save_results_dataset=False,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == 5
|
||||
|
||||
|
||||
def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_functions = {
|
||||
"llm-as-judge::base": dict(
|
||||
type="llm_as_judge",
|
||||
judge_model=judge_model_id,
|
||||
prompt_template=sample_judge_prompt_template,
|
||||
judge_score_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=[
|
||||
"categorical_count",
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
|
||||
# score batch
|
||||
response = llama_stack_client.scoring.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
save_results_dataset=False,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == 5
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping because this seems to be really slow")
|
||||
def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
||||
scoring_functions = {}
|
||||
aggr_fns = [
|
||||
"accuracy",
|
||||
"median",
|
||||
"categorical_count",
|
||||
"average",
|
||||
]
|
||||
for x in scoring_fns_list:
|
||||
if x.provider_id == "llm-as-judge":
|
||||
aggr_fns = ["categorical_count"]
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="llm_as_judge",
|
||||
judge_model=judge_model_id,
|
||||
prompt_template=sample_judge_prompt_template,
|
||||
judge_score_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
elif x.provider_id == "basic" or x.provider_id == "braintrust":
|
||||
if "regex_parser" in x.identifier:
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="regex_parser",
|
||||
parsing_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
else:
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="basic",
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
else:
|
||||
scoring_functions[x.identifier] = None
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
assert len(response.results[x].aggregated_results) == len(aggr_fns)
|
||||
66
tests/integration/tool_runtime/test_builtin_tools.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_query():
|
||||
return "What are the latest developments in quantum computing?"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_wolfram_alpha_query():
|
||||
return "What is the square root of 16?"
|
||||
|
||||
|
||||
def test_web_search_tool(llama_stack_client, sample_search_query):
|
||||
"""Test the web search tool functionality."""
|
||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||
tool_name="web_search", kwargs={"query": sample_search_query}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response.content is not None
|
||||
assert len(response.content) > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
content = json.loads(response.content)
|
||||
assert "query" in content
|
||||
assert "top_k" in content
|
||||
assert len(content["top_k"]) > 0
|
||||
|
||||
first = content["top_k"][0]
|
||||
assert "title" in first
|
||||
assert "url" in first
|
||||
|
||||
|
||||
def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
|
||||
"""Test the wolfram alpha tool functionality."""
|
||||
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
||||
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
||||
|
||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
||||
)
|
||||
|
||||
print(response.content)
|
||||
assert response.content is not None
|
||||
assert len(response.content) > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
content = json.loads(response.content)
|
||||
result = content["queryresult"]
|
||||
assert "success" in result
|
||||
assert result["success"]
|
||||
assert "pods" in result
|
||||
assert len(result["pods"]) > 0
|
||||
|
|
@ -4,29 +4,23 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from llama_stack_client.types import Document
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def empty_vector_db_registry(llama_stack_client):
|
||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
for vector_db_id in vector_dbs:
|
||||
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
def client_with_empty_registry(client_with_models):
|
||||
def clear_registry():
|
||||
vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
|
||||
for vector_db_id in vector_dbs:
|
||||
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
|
||||
clear_registry()
|
||||
yield client_with_models
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
|
||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||
llama_stack_client.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
)
|
||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
return vector_dbs
|
||||
# you must clean after the last test if you were running tests against
|
||||
# a stateful server instance
|
||||
clear_registry()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
@ -63,9 +57,15 @@ def assert_valid_response(response):
|
|||
assert isinstance(chunk.content, str)
|
||||
|
||||
|
||||
def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
|
||||
vector_db_id = single_entry_vector_db_registry[0]
|
||||
llama_stack_client.tool_runtime.rag_tool.insert(
|
||||
def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
|
||||
vector_db_id = "test_vector_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=sample_documents,
|
||||
chunk_size_in_tokens=512,
|
||||
vector_db_id=vector_db_id,
|
||||
|
|
@ -73,7 +73,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with a direct match
|
||||
query1 = "programming language"
|
||||
response1 = llama_stack_client.vector_io.query(
|
||||
response1 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query1,
|
||||
)
|
||||
|
|
@ -82,7 +82,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with semantic similarity
|
||||
query2 = "AI and brain-inspired computing"
|
||||
response2 = llama_stack_client.vector_io.query(
|
||||
response2 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query2,
|
||||
)
|
||||
|
|
@ -91,7 +91,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with limit on number of results (max_chunks=2)
|
||||
query3 = "computer"
|
||||
response3 = llama_stack_client.vector_io.query(
|
||||
response3 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query3,
|
||||
params={"max_chunks": 2},
|
||||
|
|
@ -101,7 +101,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with threshold on similarity score
|
||||
query4 = "computer"
|
||||
response4 = llama_stack_client.vector_io.query(
|
||||
response4 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query4,
|
||||
params={"score_threshold": 0.01},
|
||||
|
|
@ -110,20 +110,20 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
assert all(score >= 0.01 for score in response4.scores)
|
||||
|
||||
|
||||
def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db_registry):
|
||||
providers = [p for p in llama_stack_client.providers.list() if p.api == "vector_io"]
|
||||
def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
|
||||
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
||||
assert len(providers) > 0
|
||||
|
||||
vector_db_id = "test_vector_db"
|
||||
|
||||
llama_stack_client.vector_dbs.register(
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
# list to check memory bank is successfully registered
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert vector_db_id in available_vector_dbs
|
||||
|
||||
# URLs of documents to insert
|
||||
|
|
@ -144,14 +144,14 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
|
|||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
llama_stack_client.tool_runtime.rag_tool.insert(
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
# Query for the name of method
|
||||
response1 = llama_stack_client.vector_io.query(
|
||||
response1 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query="What's the name of the fine-tunning method used?",
|
||||
)
|
||||
|
|
@ -159,7 +159,7 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
|
|||
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
|
||||
|
||||
# Query for the name of model
|
||||
response2 = llama_stack_client.vector_io.query(
|
||||
response2 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query="Which Llama model is mentioned?",
|
||||
)
|
||||
|
|
|
|||