forked from phoenix-oss/llama-stack-mirror
[remove import *] clean up import *'s (#689)
# What does this PR do?

- As the title says, clean up `import *`'s.
- Upgrade tests to make them more robust to bad model outputs.
- Remove `import *`'s in `llama_stack/apis/*` (skipping `__init__` modules).

<img width="465" alt="image" src="https://github.com/user-attachments/assets/d8339c13-3b40-4ba5-9c53-0d2329726ee2" />

- Run `sh run_openapi_generator.sh`; no types are affected.

## Test Plan

### Providers Tests

**agents**

```
pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8
```

**inference**

```bash
# meta-reference
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

# together
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py
```

**safety**

```
pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B
```

**memory**

```
pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
```

**scoring**

```
pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
```

**datasetio**

```
pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py
pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py
```

**eval**

```
pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py
```

### Client-SDK Tests

```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk
```

### llama-stack-apps

```
PORT=5000
LOCALHOST=localhost

python -m examples.agents.hello $LOCALHOST $PORT
python -m examples.agents.inflation $LOCALHOST $PORT
python -m examples.agents.podcast_transcript $LOCALHOST $PORT
python -m examples.agents.rag_as_attachments $LOCALHOST $PORT
python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT
python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT
python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT

# Vision model
python -m examples.interior_design_assistant.app
python -m examples.agent_store.app $LOCALHOST $PORT
```

### CLI

```
which llama
llama model prompt-format -m Llama3.2-11B-Vision-Instruct
llama model list
llama stack list-apis
llama stack list-providers inference

llama stack build --template ollama --image-type conda
```

### Distributions Tests

**ollama**

```
llama stack build --template ollama --image-type conda
ollama run llama3.2:1b-instruct-fp16
llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
```

**fireworks**

```
llama stack build --template fireworks --image-type conda
llama stack run ./llama_stack/templates/fireworks/run.yaml
```

**together**

```
llama stack build --template together --image-type conda
llama stack run ./llama_stack/templates/together/run.yaml
```

**tgi**

```
llama stack run ./llama_stack/templates/tgi/run.yaml --env TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
```

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
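For reviewers unfamiliar with the cleanup pattern, every touched file follows the same shape: a wildcard import (suppressed with `# noqa: F403`) is replaced by an explicit import list, so linters can flag unused or undefined names again. A minimal before/after sketch, using names that appear in the agents tests below (any other file is analogous):

```python
# Before: the wildcard pulls in every public name, hides the test's real
# dependencies, and needs `# noqa: F403` to silence flake8.
# from llama_stack.apis.agents import *  # noqa: F403

# After: each dependency is named explicitly, so flake8/pyflakes can again
# catch unused imports (F401) and undefined names (F821).
from llama_stack.apis.agents import AgentConfig, Turn
from llama_stack.apis.inference import SamplingParams, UserMessage
```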
parent 70db039ff4
commit 3c72c034e6
99 changed files with 907 additions and 359 deletions
```diff
@@ -5,11 +5,31 @@
 # the root directory of this source tree.

 import os
+from typing import Dict, List

 import pytest
+from llama_models.llama3.api.datatypes import BuiltinTool

-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.apis.agents import (
+    AgentConfig,
+    AgentTool,
+    AgentTurnResponseEventType,
+    AgentTurnResponseStepCompletePayload,
+    AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnCompletePayload,
+    Attachment,
+    MemoryToolDefinition,
+    SearchEngineType,
+    SearchToolDefinition,
+    ShieldCallStep,
+    StepType,
+    ToolChoice,
+    ToolExecutionStep,
+    Turn,
+)
+from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
+from llama_stack.apis.safety import ViolationLevel
+from llama_stack.providers.datatypes import Api

 # How to run this test:
 #
```
```diff
@@ -6,9 +6,9 @@

 import pytest

-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.datatypes import *  # noqa: F403
-
+from llama_stack.apis.agents import AgentConfig, Turn
+from llama_stack.apis.inference import SamplingParams, UserMessage
+from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
 from .fixtures import pick_inference_model
```
```diff
@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import os
-
-import pytest
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+import base64
+import mimetypes
+import os
+from pathlib import Path
+
+import pytest
+
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
+from llama_stack.apis.datasets import Datasets

 # How to run this test:
 #
 # pytest llama_stack/providers/tests/datasetio/test_datasetio.py
```
```diff
@@ -7,8 +7,7 @@

 import pytest

-from llama_models.llama3.api import SamplingParams, URL
-
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType

 from llama_stack.apis.eval.eval import (
@@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import (
     BenchmarkEvalTaskConfig,
     ModelCandidate,
 )
+from llama_stack.apis.inference import SamplingParams
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
```
```diff
@@ -6,8 +6,14 @@

 import unittest

-from llama_models.llama3.api import *  # noqa: F403
-from llama_stack.apis.inference.inference import *  # noqa: F403
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
+
+from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )
@@ -24,7 +30,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -41,7 +47,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -69,7 +75,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -99,7 +105,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)

         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -121,7 +127,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
```
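A note on the repeated `(request)` → `(request, MODEL)` change above: `chat_completion_request_to_messages` now takes the model identifier explicitly, presumably so prompt construction can vary by model family. A hypothetical call sketch (the `MODEL` value and the `ChatCompletionRequest` fields shown are assumptions, not taken from this diff):

```python
from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

MODEL = "Llama3.1-8B-Instruct"  # assumed model descriptor

# Build a minimal request and expand it into the model's message list,
# passing the model identifier so the right prompt format is applied.
request = ChatCompletionRequest(
    model=MODEL,
    messages=[UserMessage(content="What is the capital of France?")],
)
messages = chat_completion_request_to_messages(request, MODEL)
```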
```diff
@@ -7,13 +7,32 @@

 import pytest

+from llama_models.llama3.api.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
+
 from pydantic import BaseModel, ValidationError

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-
-from llama_stack.distribution.datatypes import *  # noqa: F403
-
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    JsonSchemaResponseFormat,
+    LogProbConfig,
+    SystemMessage,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolChoice,
+    UserMessage,
+)
+from llama_stack.apis.models import Model
 from .utils import group_chunks

```
```diff
@@ -8,11 +8,16 @@ from pathlib import Path

 import pytest

-
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
+
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    SamplingParams,
+    UserMessage,
+)

 from .utils import group_chunks

 THIS_DIR = Path(__file__).parent
```
```diff
@@ -10,8 +10,7 @@ import tempfile
 import pytest
 import pytest_asyncio

-from llama_stack.apis.inference import ModelInput, ModelType
-
+from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
@@ -19,7 +18,7 @@ from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
-from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
```
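The `SqliteKVStoreConfig` import above now comes from the concrete `kvstore.config` module rather than the package root. A hypothetical usage sketch (the `db_path` field and the `set`/`get` methods are assumptions about the kvstore API, not shown in this diff):

```python
import asyncio

from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


async def main() -> None:
    # db_path is an assumed field naming the SQLite backing file.
    config = SqliteKVStoreConfig(db_path="/tmp/test_kvstore.db")
    kvstore = await kvstore_impl(config)
    await kvstore.set("key", "value")  # assumed KVStore methods
    print(await kvstore.get("key"))


asyncio.run(main())
```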
```diff
@@ -8,14 +8,18 @@ import uuid

 import pytest

-from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
-from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams
+from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse
+
+from llama_stack.apis.memory_banks import (
+    MemoryBank,
+    MemoryBanks,
+    VectorMemoryBankParams,
+)

 # How to run this test:
 #
 # pytest llama_stack/providers/tests/memory/test_memory.py
 # -m "meta_reference"
 # -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
 # -v -s --tb=short --disable-warnings
```
```diff
@@ -7,8 +7,9 @@
 import pytest
 import pytest_asyncio

-from llama_stack.apis.common.type_system import *  # noqa: F403
-
+from llama_stack.apis.common.content_types import URL
+
+from llama_stack.apis.common.type_system import StringType
 from llama_stack.apis.datasets import DatasetInput
 from llama_stack.apis.models import ModelInput
```
```diff
@@ -4,9 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import pytest
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.post_training import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+
+from llama_stack.apis.common.type_system import JobStatus
+from llama_stack.apis.post_training import (
+    Checkpoint,
+    DataConfig,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)

 # How to run this test:
 #
```
```diff
@@ -8,14 +8,24 @@ import json
 import tempfile
 from typing import Any, Dict, List, Optional

-from llama_stack.distribution.datatypes import *  # noqa: F403
+from pydantic import BaseModel
+
+from llama_stack.apis.datasets import DatasetInput
+from llama_stack.apis.eval_tasks import EvalTaskInput
+from llama_stack.apis.memory_banks import MemoryBankInput
+from llama_stack.apis.models import ModelInput
+from llama_stack.apis.scoring_functions import ScoringFnInput
+from llama_stack.apis.shields import ShieldInput

 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.datatypes import Provider, StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_remote_stack_impls
 from llama_stack.distribution.stack import construct_stack
-from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
+from llama_stack.providers.datatypes import Api, RemoteProviderConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


 class TestStack(BaseModel):
```
```diff
@@ -6,11 +6,9 @@

 import pytest

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.safety import *  # noqa: F403
-
-from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.safety import ViolationLevel
+from llama_stack.apis.shields import Shield

 # How to run this test:
 #
```
```diff
@@ -197,7 +197,7 @@ class TestScoring:
                     judge_score_regexes=[r"Score: (\d+)"],
                     aggregation_functions=aggr_fns,
                 )
-            elif x.provider_id == "basic":
+            elif x.provider_id == "basic" or x.provider_id == "braintrust":
                 if "regex_parser" in x.identifier:
                     scoring_functions[x.identifier] = RegexParserScoringFnParams(
                         aggregation_functions=aggr_fns,
```