build: format codebase imports using ruff linter (#1028)

# What does this PR do?

- Configured the ruff linter to automatically fix import-sorting issues.
- Set `--exit-non-zero-on-fix` so the hook exits with a non-zero code whenever
fixes are applied.
- Enabled the `I` (isort) rule selection to focus on import-related lint rules.
- Ran the linter and formatted all codebase imports accordingly (see the
equivalent invocation below).
- Removed the `black` dependency from the "dev" group, since we use ruff for
formatting.
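For reference, the hook configured below should be roughly equivalent to running ruff directly from the repo root (invocation reconstructed from the flags above, not copied from the repo config):

```bash
# Sort imports in place; exit non-zero when fixes were applied so hooks/CI fail loudly.
ruff check --select I --fix --exit-non-zero-on-fix .
```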



## Test Plan
Run the pre-commit hooks over the whole repository: the first run may apply
import fixes (and exit non-zero because of `--exit-non-zero-on-fix`), and a
second run should report no further changes. See the commands below.
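A minimal way to re-execute the plan, assuming pre-commit is installed locally:

```bash
# First pass may rewrite imports (and exit non-zero); a second pass should be clean.
pre-commit run --all-files
pre-commit run --all-files
```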


Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han committed 2025-02-13 19:06:21 +01:00 (via GitHub)
commit e4a1579e63
parent 1527c30107
140 changed files with 139 additions and 243 deletions

View file

@@ -29,10 +29,12 @@ repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
   rev: v0.9.4
   hooks:
+    # Run the linter with import sorting.
     - id: ruff
       args: [
         --fix,
-        --exit-non-zero-on-fix
+        --exit-non-zero-on-fix,
+        --select, I,
       ]
     - id: ruff-format

View file

@@ -15,14 +15,14 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )
 
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 
-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,

View file

@@ -13,7 +13,6 @@ from termcolor import cprint
 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
-
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

View file

@@ -8,7 +8,6 @@ from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union
 
 from llama_models.llama3.api.datatypes import ToolCall
 from llama_models.schema_utils import json_schema_type, register_schema
-
 from pydantic import BaseModel, Field, model_validator

View file

@@ -8,7 +8,6 @@ from enum import Enum
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
-
 
 from llama_stack.apis.common.content_types import URL

View file

@@ -12,8 +12,8 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )
 
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod

View file

@@ -5,11 +5,9 @@
 # the root directory of this source tree.
 
 from enum import Enum
 from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
-
-
 
 from llama_stack.apis.inference import Message

View file

@@ -4,5 +4,5 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .tools import *  # noqa: F401 F403
 from .rag_tool import *  # noqa: F401 F403
+from .tools import *  # noqa: F401 F403

View file

@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Protocol, runtime_checkable
 
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

View file

@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
 
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

View file

@@ -16,11 +16,9 @@ from pathlib import Path
 from typing import Dict, List, Optional
 
 import httpx
 from llama_models.datatypes import Model
-
 from llama_models.sku_list import LlamaDownloadInfo
-
 from pydantic import BaseModel, ConfigDict
 from rich.console import Console
 from rich.progress import (
     BarColumn,

View file

@@ -8,7 +8,6 @@ import argparse
 import json
 
 from llama_models.sku_list import resolve_model
 from termcolor import colored
-
 
 from llama_stack.cli.subcommand import Subcommand

View file

@@ -11,7 +11,6 @@ from llama_stack.cli.model.download import ModelDownload
 from llama_stack.cli.model.list import ModelList
 from llama_stack.cli.model.prompt_format import ModelPromptFormat
 from llama_stack.cli.model.verify_download import ModelVerifyDownload
-
 from llama_stack.cli.subcommand import Subcommand

View file

@@ -8,7 +8,7 @@ import argparse
 import textwrap
 from io import StringIO
 
-from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
+from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
 from llama_stack.cli.subcommand import Subcommand

View file

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
 from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.datatypes import SamplingParams
 from llama_models.sku_list import LlamaDownloadInfo
-
 from pydantic import BaseModel, ConfigDict, Field

View file

@@ -21,12 +21,11 @@ from prompt_toolkit.validation import Validator
 from termcolor import cprint
 
 from llama_stack.cli.table import print_table
 from llama_stack.distribution.build import (
+    SERVER_DEPENDENCIES,
+    ImageType,
     build_image,
     get_provider_dependencies,
-    ImageType,
-    SERVER_DEPENDENCIES,
 )
 from llama_stack.distribution.datatypes import (
     BuildConfig,

View file

@@ -8,6 +8,7 @@ from datetime import datetime
 import pytest
 import yaml
+
 from llama_stack.distribution.configure import (
     LLAMA_STACK_RUN_CONFIG_VERSION,
     parse_and_maybe_upgrade_config,

View file

@@ -8,7 +8,6 @@ import importlib.resources
 import logging
 import sys
-
 from enum import Enum
 from pathlib import Path
 from typing import Dict, List
@@ -16,11 +15,8 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.distribution.datatypes import BuildConfig, Provider
 from llama_stack.distribution.distribution import get_provider_registry
-
 from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
-
 from llama_stack.distribution.utils.exec import run_command, run_with_pty
-
 from llama_stack.providers.datatypes import Api

View file

@@ -5,18 +5,16 @@
 # the root directory of this source tree.
 
 import inspect
 import json
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import Any, get_args, get_origin, Type, Union
+from typing import Any, Type, Union, get_args, get_origin
 
 import httpx
 from pydantic import BaseModel, parse_obj_as
 from termcolor import cprint
-
 
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 from llama_stack.providers.datatypes import RemoteProviderConfig
 
 _CLIENT_CLASSES = {}

View file

@@ -5,12 +5,11 @@
 # the root directory of this source tree.
 import logging
 import textwrap
 from typing import Any, Dict
 
 from llama_stack.distribution.datatypes import (
-    DistributionSpec,
     LLAMA_STACK_RUN_CONFIG_VERSION,
+    DistributionSpec,
     Provider,
     StackRunConfig,
 )
@@ -20,7 +19,6 @@ from llama_stack.distribution.distribution import (
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
-
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
 logger = logging.getLogger(__name__)

View file

@@ -13,10 +13,21 @@ import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, get_args, get_origin, Optional, TypeVar
+from typing import Any, Optional, TypeVar, get_args, get_origin
 
 import httpx
 import yaml
+from llama_stack_client import (
+    NOT_GIVEN,
+    APIResponse,
+    AsyncAPIResponse,
+    AsyncLlamaStackClient,
+    AsyncStream,
+    LlamaStackClient,
+)
+from pydantic import BaseModel, TypeAdapter
+from rich.console import Console
+from termcolor import cprint
 
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
@@ -35,17 +46,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
     setup_logger,
     start_trace,
 )
-from llama_stack_client import (
-    APIResponse,
-    AsyncAPIResponse,
-    AsyncLlamaStackClient,
-    AsyncStream,
-    LlamaStackClient,
-    NOT_GIVEN,
-)
-from pydantic import BaseModel, TypeAdapter
-from rich.console import Console
-from termcolor import cprint
 
 T = TypeVar("T")

View file

@@ -7,7 +7,6 @@
 from typing import Any, Dict
 
 from llama_stack.distribution.datatypes import RoutedProtocol
 from llama_stack.distribution.store import DistributionRegistry
-
 from llama_stack.providers.datatypes import Api, RoutingTable

View file

@@ -6,7 +6,7 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     AppEvalTaskConfig,

View file

@@ -10,11 +10,8 @@ from typing import Dict, List
 from pydantic import BaseModel
 
-
 from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
-
 from llama_stack.distribution.resolver import api_protocol_map
-
 from llama_stack.providers.datatypes import Api

View file

@@ -7,9 +7,9 @@
 import argparse
 import asyncio
 import functools
-import logging
 import inspect
 import json
+import logging
 import os
 import signal
 import sys
@@ -21,7 +21,8 @@ from pathlib import Path
 from typing import Any, List, Union
 
 import yaml
-from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
+from fastapi import Body, FastAPI, HTTPException, Request
+from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError

View file

@@ -8,9 +8,9 @@ import os
 import pytest
 import pytest_asyncio
+
 from llama_stack.apis.inference import Model
 from llama_stack.apis.vector_dbs import VectorDB
-
 from llama_stack.distribution.store.registry import (
     CachedDiskDistributionRegistry,
     DiskDistributionRegistry,

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import os
 from typing import Optional
-
 
 from llama_stack_client import LlamaStackClient

View file

@@ -10,7 +10,6 @@ from page.distribution.models import models
 from page.distribution.scoring_functions import scoring_functions
 from page.distribution.shields import shields
 from page.distribution.vector_dbs import vector_dbs
-
 from streamlit_option_menu import option_menu

View file

@@ -8,7 +8,6 @@ import json
 import pandas as pd
 import streamlit as st
-
 from modules.api import llama_stack_api
 from modules.utils import process_dataset

View file

@@ -7,9 +7,7 @@
 import json
 
 import pandas as pd
 import streamlit as st
-
-
 from modules.api import llama_stack_api

View file

@@ -9,7 +9,6 @@ from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.memory_insert_params import Document
-
 from modules.api import llama_stack_api
 from modules.utils import data_url_from_file

View file

@@ -7,7 +7,6 @@
 import os
 from pathlib import Path
 
-
 LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
 
 DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

View file

@@ -8,13 +8,11 @@ import inspect
 import json
 import logging
 from enum import Enum
-from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
+from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
 
-
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType
 from typing_extensions import Annotated
 
-
 log = logging.getLogger(__name__)

View file

@@ -11,7 +11,6 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.datasets import Dataset
 from llama_stack.apis.datatypes import Api
-
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.models import Model

View file

@@ -42,10 +42,10 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.common.content_types import (
+    URL,
     TextContentItem,
     ToolCallDelta,
     ToolCallParseStatus,
-    URL,
 )
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,

View file

@@ -6,11 +6,9 @@
 import asyncio
 import logging
 from typing import List
-
 
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
-
 
 log = logging.getLogger(__name__)

View file

@@ -41,7 +41,6 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )

View file

@@ -15,14 +15,12 @@ import pandas
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.datasets import Dataset
-
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl
-
 
 from .config import LocalFSDatasetIOConfig
 
 DATASETS_PREFIX = "localfs_datasets:"

View file

@@ -15,7 +15,6 @@ from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
-
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -28,7 +27,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
-
 
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "eval_tasks:"

View file

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
 from pydantic import BaseModel, field_validator
 
 from llama_stack.apis.inference import QuantizationConfig
-
 from llama_stack.providers.utils.inference import supported_inference_models

View file

@@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel
@@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     ResponseFormatType,
 )
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,

View file

@@ -46,8 +46,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,

View file

@@ -22,16 +22,13 @@ from typing import Callable, Generator, Literal, Optional, Union
 import torch
 import zmq
-
 from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_group,
     get_model_parallel_rank,
     get_model_parallel_src_rank,
 )
-
 from pydantic import BaseModel, Field
-from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated
-
 
 from llama_stack.providers.utils.inference.prompt_adapter import (

View file

@@ -8,7 +8,6 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import collections
 import logging
-
 from typing import Optional, Type
@@ -23,7 +22,7 @@ except ImportError:
     raise
 
 import torch
-from torch import nn, Tensor
+from torch import Tensor, nn
 
 class Fp8ScaledWeights:

View file

@@ -10,9 +10,9 @@
 import unittest
 
 import torch
-
-from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
-from hypothesis import given, settings, strategies as st
+from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
+from hypothesis import given, settings
+from hypothesis import strategies as st
 from torch import Tensor

View file

@@ -12,18 +12,13 @@ import os
 from typing import Any, Dict, List, Optional
 
 import torch
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-
 from llama_models.datatypes import CheckpointQuantizationFormat
-
 from llama_models.llama3.api.args import ModelArgs
-
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
-
 from llama_models.sku_list import resolve_model
-from torch import nn, Tensor
+from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from llama_stack.apis.inference import QuantizationType

View file

@@ -16,14 +16,12 @@ from pathlib import Path
 from typing import Optional
 
 import fire
 import torch
-
 from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_rank,
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock

View file

@@ -15,9 +15,9 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    ToolConfig,
 )
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (

View file

@@ -37,9 +37,9 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )

View file

@@ -15,10 +15,8 @@ from typing import Any, Callable, Dict
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-
 from pydantic import BaseModel
-
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_1 import lora_llama3_1_8b

View file

@@ -13,7 +13,6 @@
 from typing import Any, Dict, List, Mapping
 
 import numpy as np
 from torch.utils.data import Dataset
-
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
 from torchtune.data._messages import validate_messages

View file

@@ -18,9 +18,9 @@ from llama_models.sku_list import resolve_model
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
+from torchtune import modules, training
+from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -44,14 +44,11 @@ from llama_stack.apis.post_training import (
     OptimizerConfig,
     TrainingConfig,
 )
-
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
-
 from llama_stack.providers.inline.post_training.common.validator import (
     validate_input_dataset_schema,
 )
-
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
     TorchtuneCheckpointer,

View file

@@ -21,7 +21,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import CodeScannerConfig
 
-
 log = logging.getLogger(__name__)
 
 ALLOWED_CODE_SCANNER_MODEL_IDS = [

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import re
-
 from string import Template
 from typing import Any, Dict, List, Optional
@@ -25,10 +24,8 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
-
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import LlamaGuardConfig
 
-
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 SAFE_RESPONSE = "safe"

View file

@@ -8,7 +8,6 @@ import logging
 from typing import Any, Dict, List
 
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
 
 from llama_stack.apis.inference import Message
@@ -19,7 +18,6 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (

View file

@@ -14,13 +14,13 @@ from llama_stack.apis.scoring import (
     ScoringResult,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
     get_valid_schemas,
     validate_dataset_schema,
 )
+
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn

View file

@@ -7,7 +7,6 @@
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
-
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 equality = ScoringFn(
     identifier="basic::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 subset_of = ScoringFn(
     identifier="basic::subset_of",
     description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.scoring import ScoringResultRow

View file

@@ -29,9 +29,7 @@ from llama_stack.apis.scoring import (
     ScoringResultRow,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
@@ -39,8 +37,8 @@ from llama_stack.providers.utils.common.data_schema_validator import (
     validate_dataset_schema,
     validate_row_schema,
 )
-
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
+
 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
 from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
     description=(

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
     description=(

View file

@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.inference.inference import Inference
-
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
     ScoreResponse,
@@ -26,7 +25,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 
-
 LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]

View file

@@ -7,7 +7,6 @@
 from llama_stack.apis.common.type_system import NumberType
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
 
-
 llm_as_judge_base = ScoringFn(
     identifier="llm-as-judge::base",
     description="Llm As Judge Scoring Function",

View file

@@ -4,18 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.inference.inference import Inference
-
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 
-
 from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
 from .fn_defs.llm_as_judge_base import llm_as_judge_base

View file

@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 from llama_stack.apis.telemetry import Telemetry
+
 from .config import SampleConfig

View file

@@ -82,7 +82,11 @@ import sys as _sys
 # them with linters - they're used in code_execution.py
 from contextlib import (  # noqa
     contextmanager as _contextmanager,
+)
+from contextlib import (
     redirect_stderr as _redirect_stderr,
+)
+from contextlib import (
     redirect_stdout as _redirect_stdout,
 )
 from multiprocessing.connection import Connection as _Connection

View file

@@ -9,7 +9,6 @@ from jinja2 import Template
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage
-
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,

View file

@@ -11,9 +11,9 @@ import string
 from typing import Any, Dict, List, Optional
 
 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     TextContentItem,
-    URL,
 )
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (

View file

@@ -7,6 +7,7 @@
 from typing import Dict
 
 from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import FaissImplConfig

View file

@@ -8,11 +8,9 @@ import base64
 import io
 import json
 import logging
-
 from typing import Any, Dict, List, Optional
 
 import faiss
 import numpy as np
-
 from numpy.typing import NDArray

View file

@@ -5,7 +5,9 @@
 # the root directory of this source tree.
 from typing import Dict
+
 from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import SQLiteVectorIOConfig

View file

@@ -5,9 +5,10 @@
 # the root directory of this source tree.
 
 # config.py
-from pydantic import BaseModel
 from typing import Any, Dict
 
+from pydantic import BaseModel
+
 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,

View file

@@ -4,13 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import sqlite3
-import sqlite_vec
-import struct
 import logging
+import sqlite3
+import struct
+from typing import Any, Dict, List, Optional
 
 import numpy as np
+import sqlite_vec
 from numpy.typing import NDArray
-from typing import List, Optional, Dict, Any
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO

View file

@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 from llama_stack.apis.agents import Agents
+
 from .config import SampleConfig

View file

@@ -9,7 +9,6 @@ import datasets as hf_datasets
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.datasets import Dataset
-
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl

View file

@@ -31,13 +31,13 @@ from llama_stack.apis.inference import (
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_strategy_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )

View file

@@ -29,8 +29,8 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,

View file

@@ -26,8 +26,8 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,

View file

@@ -31,8 +31,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,

View file

@@ -31,9 +31,9 @@ from llama_stack.apis.inference import (
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
+    ModelRegistryHelper,
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
-    ModelRegistryHelper,
 )
 from .groq_utils import (

View file

@@ -24,10 +24,8 @@ from groq.types.chat.chat_completion_user_message_param import (
 )
 from groq.types.chat.completion_create_params import CompletionCreateParams
 from groq.types.shared.function_definition import FunctionDefinition
-
 from llama_models.llama3.api.datatypes import ToolParamDefinition
-
 
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -47,9 +45,9 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-    convert_tool_call,
     UnparseableToolCall,
+    convert_tool_call,
+    get_sampling_strategy_options,
 )

View file

@@ -29,8 +29,8 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

View file

@@ -22,17 +22,35 @@ from llama_models.llama3.api.datatypes import (
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+)
+from openai.types.chat import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
+from openai.types.chat import (
     ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+)
+from openai.types.chat import (
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+)
+from openai.types.chat import (
     ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+)
+from openai.types.chat import (
     ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
+)
+from openai.types.chat import (
     ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+)
+from openai.types.chat import (
     ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+)
+from openai.types.chat import (
     ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
 )
 from openai.types.chat.chat_completion import (
     Choice as OpenAIChoice,
+)
+from openai.types.chat.chat_completion import (
     ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
 )
 from openai.types.chat.chat_completion_content_part_image_param import (
@@ -69,7 +87,6 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
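
The explosion of repeated `from openai.types.chat import (...)` statements above is expected: isort-style sorting, which ruff's `I` rules implement, gives each `as` alias its own statement unless combining is enabled. A sketch of the two styles using stdlib names rather than this repo's imports:

```python
# Split style: what the sorter emits by default (one aliased name per statement).
from json import JSONDecoder as Decoder
from json import JSONEncoder as Encoder

# Combined style: what you would get back by opting in via pyproject.toml —
# assuming a config like the following (not part of this PR, illustration only):
#   [tool.ruff.lint.isort]
#   combine-as-imports = true
# from json import JSONDecoder as Decoder, JSONEncoder as Encoder

print(Decoder, Encoder)
```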


@@ -8,7 +8,6 @@ from typing import Any, Dict
 
 from pydantic import BaseModel
 
-
 DEFAULT_OLLAMA_URL = "http://localhost:11434"


@@ -36,14 +36,14 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
+    ModelRegistryHelper,
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
-    ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,


@@ -8,14 +8,12 @@ from typing import AsyncGenerator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
 
 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,


@@ -24,8 +24,8 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,


@@ -6,6 +6,7 @@
 
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.models import Model
+
 from .config import SampleConfig
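
The added blank line before `from .config import SampleConfig` here (and in the sibling `sample` providers below) is the sorter separating relative imports into their own trailing section. A rough, self-contained sketch of that bucketing — the section table is abbreviated and hypothetical, not ruff's actual implementation, and the real config additionally distinguishes a first-party section for `llama_stack`:

```python
# Toy model of isort-style sectioning: stdlib first, third party next,
# relative imports last, with a blank line between sections.
STDLIB = {"json", "logging", "typing"}  # abbreviated stdlib list for the sketch


def section(module: str) -> int:
    if module.startswith("."):
        return 2  # relative/local imports form the final section
    if module.partition(".")[0] in STDLIB:
        return 0  # standard library
    return 1  # everything else treated as third party in this sketch


modules = [".config", "llama_stack.apis.models", "logging", "pydantic"]
for mod in sorted(modules, key=lambda m: (section(m), m)):
    print(section(mod), mod)
```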


@@ -33,13 +33,13 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,


@@ -30,8 +30,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,


@@ -13,10 +13,14 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
 from openai import OpenAI
 
-from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta
+from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
@@ -31,26 +35,22 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    CompletionMessage,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    ChatCompletionResponseEvent,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict,
-    get_sampling_options,
-    process_completion_response,
-    process_completion_stream_response,
     OpenAICompatCompletionResponse,
     UnparseableToolCall,
+    convert_message_to_openai_dict,
     convert_tool_call,
+    get_sampling_options,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,


@@ -6,11 +6,9 @@
 
 import json
 import logging
-
 from typing import Any, Dict, List
 
 from llama_stack.apis.inference import Message
-
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -23,7 +21,6 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 
 from .config import BedrockSafetyConfig
-
 
 logger = logging.getLogger(__name__)


@@ -6,6 +6,7 @@
 
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.shields import Shield
+
 from .config import SampleConfig


@@ -7,7 +7,6 @@
 
 from pydantic import BaseModel
 
 from .config import ModelContextProtocolConfig
-
 from .model_context_protocol import ModelContextProtocolToolRuntimeImpl


@@ -21,6 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
+
 from .config import ChromaRemoteImplConfig
 
 log = logging.getLogger(__name__)


@@ -10,15 +10,13 @@ from typing import Any, Dict, List, Optional, Tuple
 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
-from psycopg2.extras import execute_values, Json
+from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
-
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
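
One more sorter default is visible in this hunk: within a section, plain `import x` statements come before `from x import y` statements (ruff's `force-sort-within-sections` setting is off by default), which is why `import psycopg2` stays above `from numpy.typing import NDArray` despite alphabetical order. A small sketch of that two-level sort key — illustrative only, not ruff's real code:

```python
# Sort key sketch: straight imports first (False sorts before True),
# then alphabetical by module within each kind.
stmts = [
    "from psycopg2 import sql",
    "import psycopg2",
    "from numpy.typing import NDArray",
]


def key(stmt: str):
    is_from = stmt.startswith("from ")
    module = stmt.split()[1]
    return (is_from, module)


for s in sorted(stmts, key=key):
    print(s)
# import psycopg2
# from numpy.typing import NDArray
# from psycopg2 import sql
```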


@@ -20,6 +20,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
+
 from .config import QdrantConfig
 
 log = logging.getLogger(__name__)


@@ -6,6 +6,7 @@
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import VectorIO
+
 from .config import SampleConfig
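
To re-run the same pass locally (useful when rebasing over this commit), something like the wrapper below mirrors the hook's behavior; the wrapper itself is a hypothetical convenience, only the ruff flags come from this PR:

```python
# Hypothetical helper that invokes the ruff CLI with this PR's flags;
# --select I restricts the run to the import-sorting rules.
import subprocess
import sys

cmd = ["ruff", "check", "--fix", "--exit-non-zero-on-fix", "--select", "I", "."]
result = subprocess.run(cmd)
# ruff exits non-zero when it applied fixes, so CI can flag unsorted imports.
sys.exit(result.returncode)
```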

Some files were not shown because too many files have changed in this diff.