Merge branch 'rag_scoring_fn_1' into rag_scoring_fn_2

Xi Yan 2024-12-30 17:20:35 -08:00
commit dbecff60a4
128 changed files with 6391 additions and 493 deletions


@@ -81,14 +81,28 @@ async def agents_stack(request, inference_model, safety_shield):
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
-    models = [
-        ModelInput(
-            model_id=model,
-            model_type=ModelType.llm,
-            provider_id=providers["inference"][0].provider_id,
-        )
-        for model in inference_models
-    ]
+
+    # NOTE: meta-reference provider needs 1 provider per model, lookup provider_id from provider config
+    model_to_provider_id = {}
+    for provider in providers["inference"]:
+        if "model" in provider.config:
+            model_to_provider_id[provider.config["model"]] = provider.provider_id
+
+    models = []
+    for model in inference_models:
+        if model in model_to_provider_id:
+            provider_id = model_to_provider_id[model]
+        else:
+            provider_id = providers["inference"][0].provider_id
+        models.append(
+            ModelInput(
+                model_id=model,
+                model_type=ModelType.llm,
+                provider_id=provider_id,
+            )
+        )
+
     models.append(
         ModelInput(
             model_id="all-MiniLM-L6-v2",


@@ -5,11 +5,31 @@
 # the root directory of this source tree.

 import os
 from typing import Dict, List

 import pytest
 from llama_models.llama3.api.datatypes import BuiltinTool

-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.apis.agents import (
+    AgentConfig,
+    AgentTool,
+    AgentTurnResponseEventType,
+    AgentTurnResponseStepCompletePayload,
+    AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnCompletePayload,
+    Attachment,
+    MemoryToolDefinition,
+    SearchEngineType,
+    SearchToolDefinition,
+    ShieldCallStep,
+    StepType,
+    ToolChoice,
+    ToolExecutionStep,
+    Turn,
+)
+from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
+from llama_stack.apis.safety import ViolationLevel
+from llama_stack.providers.datatypes import Api

 # How to run this test:
 #


@@ -6,9 +6,9 @@
 import pytest

-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.apis.agents import AgentConfig, Turn
+from llama_stack.apis.inference import SamplingParams, UserMessage
+from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig

 from .fixtures import pick_inference_model


@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
-import pytest
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+import base64
+import mimetypes
+import os
+from pathlib import Path
+
+import pytest
+
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
+from llama_stack.apis.datasets import Datasets

 # How to run this test:
 #
 # pytest llama_stack/providers/tests/datasetio/test_datasetio.py
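
The new `base64` and `mimetypes` imports suggest this test now inlines its fixture file as a `data:` URL. A sketch of that pattern under that assumption; the helper name `data_url_from_file` is illustrative, not taken from the diff:

```python
import base64
import mimetypes


def data_url_from_file(file_path: str) -> str:
    # Read the raw bytes of a local fixture (e.g. a test CSV).
    with open(file_path, "rb") as f:
        content = f.read()
    # Guess the MIME type from the extension; None for unknown extensions.
    mime_type, _ = mimetypes.guess_type(file_path)
    b64 = base64.b64encode(content).decode("utf-8")
    return f"data:{mime_type};base64,{b64}"
```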


@@ -6,8 +6,14 @@
 import unittest

-from llama_models.llama3.api import *  # noqa: F403
-from llama_stack.apis.inference.inference import *  # noqa: F403
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )
@@ -24,7 +30,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -41,7 +47,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -69,7 +75,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -99,7 +105,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -121,7 +127,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
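
Every call site in this file now threads a model identifier into `chat_completion_request_to_messages`, presumably so the adapter can apply model-family-specific prompt formatting. A sketch of the updated calling pattern; the `MODEL` value is an assumed example:

```python
from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

MODEL = "Llama3.1-8B-Instruct"  # assumed value; the test defines its own constant

request = ChatCompletionRequest(
    model=MODEL,
    messages=[UserMessage(content="What is the capital of France?")],
)
# The second argument selects model-specific system-prompt assembly.
messages = chat_completion_request_to_messages(request, MODEL)
```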


@@ -7,13 +7,32 @@
 import pytest

+from llama_models.llama3.api.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
 from pydantic import BaseModel, ValidationError

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    JsonSchemaResponseFormat,
+    LogProbConfig,
+    SystemMessage,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolChoice,
+    UserMessage,
+)
+from llama_stack.apis.models import Model

 from .utils import group_chunks
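
Among the newly explicit names, `JsonSchemaResponseFormat` is the structured-output type. A hedged sketch of the usual pydantic-schema pattern; the `json_schema` field name and this construction are assumptions about the API, not shown in the diff:

```python
from pydantic import BaseModel

from llama_stack.apis.inference import JsonSchemaResponseFormat


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int


# Constrain a chat completion to this JSON schema (passed as response_format).
response_format = JsonSchemaResponseFormat(
    json_schema=AnswerFormat.model_json_schema()
)
```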


@@ -8,11 +8,16 @@ from pathlib import Path
 import pytest

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    SamplingParams,
+    UserMessage,
+)

 from .utils import group_chunks

 THIS_DIR = Path(__file__).parent
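
`ImageContentItem` and `TextContentItem` are the multimodal message parts this vision test exercises. A hedged sketch of the message shape they imply; the `url`/`text` field names are assumptions about this API generation:

```python
from llama_stack.apis.common.content_types import (
    URL,
    ImageContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import UserMessage

# A mixed image+text turn; the image URL is a placeholder.
message = UserMessage(
    content=[
        ImageContentItem(url=URL(uri="https://example.com/sample.jpg")),
        TextContentItem(text="Describe what is in this image."),
    ]
)
```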


@@ -10,8 +10,7 @@ import tempfile
 import pytest
 import pytest_asyncio

-from llama_stack.apis.inference import ModelInput, ModelType
+from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
@@ -19,7 +18,7 @@ from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
-from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail


@@ -8,14 +8,18 @@ import uuid
 import pytest

-from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
-from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams
+from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse
+from llama_stack.apis.memory_banks import (
+    MemoryBank,
+    MemoryBanks,
+    VectorMemoryBankParams,
+)

 # How to run this test:
 #
 # pytest llama_stack/providers/tests/memory/test_memory.py
-#   -m "meta_reference"
+#   -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
 #   -v -s --tb=short --disable-warnings
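
The updated how-to selects the `sentence_transformers` fixture and pins `EMBEDDING_DIMENSION=384`, matching the 384-dimensional vectors that `all-MiniLM-L6-v2` produces. A hedged sketch of the bank parameters such a run would register; the chunk size is an assumed value:

```python
from llama_stack.apis.memory_banks import VectorMemoryBankParams

params = VectorMemoryBankParams(
    embedding_model="all-MiniLM-L6-v2",  # emits 384-dim embeddings
    chunk_size_in_tokens=512,  # assumed; sized to the embedding model's context
)
```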


@@ -7,8 +7,9 @@
 import pytest
 import pytest_asyncio

-from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.type_system import StringType
 from llama_stack.apis.datasets import DatasetInput
 from llama_stack.apis.models import ModelInput


@@ -4,9 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import pytest

-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.post_training import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.apis.common.type_system import JobStatus
+from llama_stack.apis.post_training import (
+    Checkpoint,
+    DataConfig,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)

 # How to run this test:
 #


@@ -8,14 +8,24 @@ import json
 import tempfile
 from typing import Any, Dict, List, Optional

-from llama_stack.distribution.datatypes import *  # noqa: F403
+from pydantic import BaseModel
+
+from llama_stack.apis.datasets import DatasetInput
+from llama_stack.apis.eval_tasks import EvalTaskInput
+from llama_stack.apis.memory_banks import MemoryBankInput
+from llama_stack.apis.models import ModelInput
+from llama_stack.apis.scoring_functions import ScoringFnInput
+from llama_stack.apis.shields import ShieldInput
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.datatypes import Provider, StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_remote_stack_impls
 from llama_stack.distribution.stack import construct_stack
-from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
+from llama_stack.providers.datatypes import Api, RemoteProviderConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 class TestStack(BaseModel):


@@ -6,11 +6,9 @@
 import pytest

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.safety import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.safety import ViolationLevel
+from llama_stack.apis.shields import Shield

 # How to run this test:
 #