build: format codebase imports using ruff linter (#1028)

# What does this PR do?

- Configured the ruff linter to automatically fix import-sorting issues.
- Set `--exit-non-zero-on-fix` so the command exits with a non-zero status
whenever fixes are applied.
- Enabled the `I` rule selection to focus on import-related linting rules
(a configuration sketch follows below).
- Ran the linter and reformatted all imports across the codebase accordingly.
- Removed the `black` dependency from the "dev" group, since we use ruff.
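
For reference, a minimal sketch of what the bullets above amount to in
`pyproject.toml` — an assumption about the layout, not a verbatim excerpt
from this PR:

```toml
# Minimal sketch: enable ruff's import-sorting ("I" / isort-style) rules.
# The exact table layout in this repository may differ.
[tool.ruff.lint]
select = ["I"]
```

With the `I` rules enabled, ruff's isort implementation defaults to
`order-by-type` sorting, which is why constants (e.g. `URL`) and classes
(e.g. `ModelRegistryHelper`) now precede functions (e.g. `build_model_alias`)
throughout the hunks below. The `--exit-non-zero-on-fix` flag is typically
passed as a hook argument (e.g. via pre-commit) so CI fails whenever imports
needed fixing.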

Signed-off-by: Sébastien Han <seb@redhat.com>


## Test Plan
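A minimal sketch for re-executing the check locally, assuming ruff is
installed and the `I` rules are selected as sketched above:

```sh
# Sort imports in place; exits non-zero if any file was modified.
ruff check --select I --fix --exit-non-zero-on-fix .

# Verify nothing is left to fix (should exit 0 on a clean tree).
ruff check --select I .
```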


Sébastien Han <seb@redhat.com> committed 2025-02-13 19:06:21 +01:00 (via GitHub)
parent 1527c30107 · commit e4a1579e63
140 changed files with 139 additions and 243 deletions

View file

@ -11,7 +11,6 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.datatypes import Api
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.models import Model

View file

@ -42,10 +42,10 @@ from llama_stack.apis.agents import (
Turn,
)
from llama_stack.apis.common.content_types import (
URL,
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,

View file

@ -6,11 +6,9 @@
import asyncio
import logging
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
log = logging.getLogger(__name__)

View file

@ -41,7 +41,6 @@ from llama_stack.apis.tools import (
ToolInvocationResult,
)
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)

View file

@ -15,14 +15,12 @@ import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:"

View file

@ -15,7 +15,6 @@ from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
@ -28,7 +27,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"

View file

@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
from pydantic import BaseModel, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models

View file

@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel
@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
)
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,

View file

@ -46,8 +46,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,

View file

@ -22,16 +22,13 @@ from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from typing_extensions import Annotated
from llama_stack.providers.utils.inference.prompt_adapter import (

View file

@ -8,7 +8,6 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
import logging
from typing import Optional, Type
@ -23,7 +22,7 @@ except ImportError:
raise
import torch
from torch import nn, Tensor
from torch import Tensor, nn
class Fp8ScaledWeights:

View file

@ -10,9 +10,9 @@
import unittest
import torch
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
from hypothesis import given, settings, strategies as st
from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
from hypothesis import given, settings
from hypothesis import strategies as st
from torch import Tensor

View file

@ -12,18 +12,13 @@ import os
from typing import Any, Dict, List, Optional
import torch
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from torch import nn, Tensor
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType

View file

@ -16,14 +16,12 @@ from pathlib import Path
from typing import Optional
import fire
import torch
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock

View file

@ -15,9 +15,9 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
ToolConfig,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (

View file

@ -37,9 +37,9 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)

View file

@ -15,10 +15,8 @@ from typing import Any, Callable, Dict
import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b

View file

@ -13,7 +13,6 @@
from typing import Any, Dict, List, Mapping
import numpy as np
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages

View file

@ -18,9 +18,9 @@ from llama_models.sku_list import resolve_model
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
@ -44,14 +44,11 @@ from llama_stack.apis.post_training import (
OptimizerConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,

View file

@ -21,7 +21,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import CodeScannerConfig
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import re
from string import Template
from typing import Any, Dict, List, Optional
@ -25,10 +24,8 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import LlamaGuardConfig
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"

View file

@ -8,7 +8,6 @@ import logging
from typing import Any, Dict, List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import Message
@ -19,7 +18,6 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (

View file

@ -14,13 +14,13 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn

View file

@ -7,7 +7,6 @@
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

View file

@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",

View file

@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
subset_of = ScoringFn(
identifier="basic::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow

View file

@ -29,9 +29,7 @@ from llama_stack.apis.scoring import (
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
@ -39,8 +37,8 @@ from llama_stack.providers.utils.common.data_schema_validator import (
validate_dataset_schema,
validate_row_schema,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def

View file

@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description=(

View file

@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
factuality_fn_def = ScoringFn(
identifier="braintrust::factuality",
description=(

View file

@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
@ -26,7 +25,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]

View file

@ -7,7 +7,6 @@
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
llm_as_judge_base = ScoringFn(
identifier="llm-as-judge::base",
description="Llm As Judge Scoring Function",

View file

@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
from .fn_defs.llm_as_judge_base import llm_as_judge_base

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.telemetry import Telemetry
from .config import SampleConfig

View file

@ -82,7 +82,11 @@ import sys as _sys
# them with linters - they're used in code_execution.py
from contextlib import ( # noqa
contextmanager as _contextmanager,
)
from contextlib import (
redirect_stderr as _redirect_stderr,
)
from contextlib import (
redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection

View file

@ -9,7 +9,6 @@ from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,

View file

@ -11,9 +11,9 @@ import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (

View file

@ -7,6 +7,7 @@
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import FaissImplConfig

View file

@ -8,11 +8,9 @@ import base64
import io
import json
import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray

View file

@ -5,7 +5,9 @@
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import SQLiteVectorIOConfig

View file

@ -5,9 +5,10 @@
# the root directory of this source tree.
# config.py
from pydantic import BaseModel
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,

View file

@ -4,13 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sqlite3
import sqlite_vec
import struct
import logging
import sqlite3
import struct
from typing import Any, Dict, List, Optional
import numpy as np
import sqlite_vec
from numpy.typing import NDArray
from typing import List, Optional, Dict, Any
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.agents import Agents
from .config import SampleConfig

View file

@ -9,7 +9,6 @@ import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl

View file

@ -31,13 +31,13 @@ from llama_stack.apis.inference import (
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)

View file

@ -29,8 +29,8 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,

View file

@ -26,8 +26,8 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,

View file

@ -31,8 +31,8 @@ from llama_stack.apis.inference import (
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,

View file

@ -31,9 +31,9 @@ from llama_stack.apis.inference import (
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_model_alias_with_just_provider_model_id,
ModelRegistryHelper,
)
from .groq_utils import (

View file

@ -24,10 +24,8 @@ from groq.types.chat.chat_completion_user_message_param import (
)
from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
@ -47,9 +45,9 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
convert_tool_call,
UnparseableToolCall,
convert_tool_call,
get_sampling_strategy_options,
)

View file

@ -29,8 +29,8 @@ from llama_stack.apis.inference import (
ToolConfig,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

View file

@ -22,17 +22,35 @@ from llama_models.llama3.api.datatypes import (
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
from openai.types.chat import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
from openai.types.chat import (
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
)
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_content_part_image_param import (
@ -69,7 +87,6 @@ from llama_stack.apis.inference import (
ToolResponseMessage,
UserMessage,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)

View file

@ -8,7 +8,6 @@ from typing import Any, Dict
from pydantic import BaseModel
DEFAULT_OLLAMA_URL = "http://localhost:11434"

View file

@ -36,14 +36,14 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_model_alias_with_just_provider_model_id,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,

View file

@ -8,14 +8,12 @@ from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,

View file

@ -24,8 +24,8 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
process_chat_completion_stream_response,

View file

@ -6,6 +6,7 @@
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model
from .config import SampleConfig

View file

@ -33,13 +33,13 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,

View file

@ -30,8 +30,8 @@ from llama_stack.apis.inference import (
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,

View file

@ -13,10 +13,14 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
@ -31,26 +35,22 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
CompletionMessage,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
ChatCompletionResponseEvent,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_completion_response,
process_completion_stream_response,
OpenAICompatCompletionResponse,
UnparseableToolCall,
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,

View file

@ -6,11 +6,9 @@
import json
import logging
from typing import Any, Dict, List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -23,7 +21,6 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)

View file

@ -6,6 +6,7 @@
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shield
from .config import SampleConfig

View file

@ -7,7 +7,6 @@
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl

View file

@ -21,6 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import ChromaRemoteImplConfig
log = logging.getLogger(__name__)

View file

@ -10,15 +10,13 @@ from typing import Any, Dict, List, Optional, Tuple
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,

View file

@ -20,6 +20,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import QdrantConfig
log = logging.getLogger(__name__)

View file

@ -6,6 +6,7 @@
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import VectorIO
from .config import SampleConfig

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List, Optional
import weaviate

View file

@ -13,7 +13,6 @@ from ..conftest import (
)
from ..inference.fixtures import INFERENCE_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import AGENTS_FIXTURES

View file

@ -23,7 +23,6 @@ from llama_stack.apis.agents import (
ToolExecutionStep,
Turn,
)
from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.safety import ViolationLevel
from llama_stack.providers.datatypes import Api

View file

@ -13,7 +13,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from .fixtures import pick_inference_model
from .utils import create_agent_session

View file

@ -6,13 +6,11 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional
import pytest
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from termcolor import colored

View file

@ -8,7 +8,6 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture

View file

@ -7,16 +7,14 @@
import pytest
from ..agents.fixtures import AGENTS_FIXTURES
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from .fixtures import EVAL_FIXTURES
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import EVAL_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(

View file

@ -8,8 +8,8 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, ModelInput, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture

View file

@ -9,7 +9,6 @@ import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.eval.eval import (
AppEvalTaskConfig,
BenchmarkEvalTaskConfig,
@ -19,6 +18,7 @@ from llama_stack.apis.inference import SamplingParams
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
from .constants import JUDGE_PROMPT
# How to run this test:

View file

@ -11,13 +11,11 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.vllm import VLLMConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig

View file

@ -10,11 +10,13 @@ import pytest
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
Choice as StreamChoice,
ChoiceDelta,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from groq.types.chat.chat_completion_chunk import (
Choice as StreamChoice,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
@ -23,6 +25,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionRequest,

View file

@ -5,11 +5,11 @@
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import Inference
from llama_stack.providers.remote.inference.groq import get_adapter_impl
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig

View file

@ -8,7 +8,6 @@ from unittest.mock import AsyncMock, patch
import pytest
# How to run this test:
#
# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"

View file

@ -6,7 +6,6 @@
import pytest
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
@ -15,7 +14,6 @@ from llama_models.llama3.api.datatypes import (
ToolParamDefinition,
ToolPromptFormat,
)
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.content_types import ToolCallParseStatus
@ -35,7 +33,6 @@ from llama_stack.apis.models import ListModelsResponse, Model
from .utils import group_chunks
# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py

View file

@ -7,9 +7,7 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from .fixtures import POST_TRAINING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [

View file

@ -8,13 +8,10 @@ import pytest
import pytest_asyncio
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import StringType
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture

View file

@ -12,10 +12,8 @@ import pytest
from llama_models.datatypes import CoreModelId
from llama_models.sku_list import all_registered_models
from pytest import ExitCode
from pytest_html.basereport import _process_outcome
INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
SUPPORTED_MODELS = {

View file

@ -7,11 +7,9 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import SAFETY_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{

View file

@ -8,14 +8,11 @@ import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.shields import ShieldInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture

View file

@ -7,7 +7,6 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import SCORING_FIXTURES

View file

@ -8,10 +8,10 @@ import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail

View file

@ -11,11 +11,9 @@ from ..conftest import (
get_provider_fixture_overrides_from_test_config,
get_test_config_for_api,
)
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import VECTOR_IO_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{

View file

@ -12,7 +12,6 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig

View file

@ -9,10 +9,8 @@ import uuid
import pytest
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
# How to run this test:

View file

@ -12,8 +12,7 @@ from pathlib import Path
import pytest
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_doc, URL
from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"

View file

@ -12,7 +12,6 @@ from llama_stack.apis.common.type_system import (
CompletionInputType,
StringType,
)
from llama_stack.distribution.datatypes import Api

View file

@ -11,7 +11,6 @@ from urllib.parse import unquote
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.providers.utils.memory.vector_store import parse_data_url

View file

@ -11,7 +11,6 @@ from llama_models.sku_list import all_registered_models
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)

View file

@ -13,7 +13,6 @@ from llama_models.datatypes import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason, ToolCall
from openai.types.chat import ChatCompletionMessageToolCall
@ -26,7 +25,6 @@ from llama_stack.apis.common.content_types import (
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -39,7 +37,6 @@ from llama_stack.apis.inference import (
Message,
TokenLogProbs,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)

View file

@ -13,7 +13,7 @@ import re
from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import is_multimodal, ModelFamily
from llama_models.datatypes import ModelFamily, is_multimodal
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
RawContent,
@ -47,9 +47,9 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
SystemMessage,
SystemMessageBehavior,
ToolChoice,
UserMessage,
SystemMessageBehavior,
)
from llama_stack.providers.utils.inference import supported_inference_models

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
from datetime import datetime
from typing import List, Optional

View file

@ -15,13 +15,14 @@ from urllib.parse import unquote
import chardet
import httpx
import numpy as np
from llama_models.llama3.api.tokenizer import Tokenizer
from numpy.typing import NDArray
from pypdf import PdfReader
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
@ -30,9 +31,6 @@ from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from numpy.typing import NDArray
from pypdf import PdfReader
log = logging.getLogger(__name__)