Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
build: format codebase imports using ruff linter (#1028)
# What does this PR do?

- Configured the ruff linter to automatically fix import sorting issues.
- Set `--exit-non-zero-on-fix` to ensure a non-zero exit code when fixes are applied.
- Enabled the `I` rule selection to focus on import-related linting rules.
- Ran the linter and formatted all codebase imports accordingly.
- Removed the black dependency from the "dev" group, since we use ruff.

Signed-off-by: Sébastien Han <seb@redhat.com>
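Most of the diff below is this mechanical rewrite. For orientation, here is a hedged, self-contained sketch (the module content is hypothetical, not taken from this PR) of what ruff's `I` rules do: standard-library imports come first, import groups are separated by a single blank line, and names inside one `from ... import (...)` are sorted case-sensitively, which is why `runtime_checkable` moves after `Union` throughout the hunks below.

```python
# Before sorting (hypothetical):
#   from typing import Union, Optional, runtime_checkable
#   import json
#
# After `ruff check --fix --exit-non-zero-on-fix --select I`:
import json
from typing import Optional, Protocol, Union, runtime_checkable


@runtime_checkable
class Named(Protocol):
    """Trivial use of the imports so the sketch runs and lints cleanly."""

    name: str


print(json.dumps({"import_sorting": "ok"}))
```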
Parent: 1527c30107
Commit: e4a1579e63
140 changed files with 139 additions and 243 deletions
@@ -29,10 +29,12 @@ repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
   rev: v0.9.4
   hooks:
+    # Run the linter with import sorting.
     - id: ruff
       args: [
         --fix,
-        --exit-non-zero-on-fix
+        --exit-non-zero-on-fix,
+        --select, I,
       ]
     - id: ruff-format

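To reproduce the hook's check outside pre-commit, a minimal wrapper sketch (assumes ruff is installed; the script itself is illustrative and not part of this PR):

```python
import subprocess
import sys

# Mirror the hook arguments above: auto-fix, restrict to the import-sorting
# rules ("I"), and exit non-zero when a fix was applied so CI keeps failing
# until the fixed files are committed.
result = subprocess.run(
    ["ruff", "check", "--fix", "--exit-non-zero-on-fix", "--select", "I", "."]
)
sys.exit(result.returncode)
```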
@@ -15,14 +15,14 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )

 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,
@@ -13,7 +13,6 @@ from termcolor import cprint
 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
-
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -8,7 +8,6 @@ from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union

 from llama_models.llama3.api.datatypes import ToolCall
-
 from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field, model_validator

@@ -8,7 +8,6 @@ from enum import Enum
 from typing import Any, Dict, Optional

 from llama_models.schema_utils import json_schema_type
-
 from pydantic import BaseModel

 from llama_stack.apis.common.content_types import URL
@@ -12,8 +12,8 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )

 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
@@ -5,11 +5,9 @@
 # the root directory of this source tree.

 from enum import Enum
-
 from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel

 from llama_stack.apis.inference import Message
@@ -4,5 +4,5 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .tools import *  # noqa: F401 F403
 from .rag_tool import *  # noqa: F401 F403
+from .tools import *  # noqa: F401 F403
@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Protocol, runtime_checkable

-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable

-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

@@ -16,11 +16,9 @@ from pathlib import Path
 from typing import Dict, List, Optional

 import httpx
-
 from llama_models.datatypes import Model
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict
-
 from rich.console import Console
 from rich.progress import (
     BarColumn,
@@ -8,7 +8,6 @@ import argparse
 import json

 from llama_models.sku_list import resolve_model
-
 from termcolor import colored

 from llama_stack.cli.subcommand import Subcommand
@@ -11,7 +11,6 @@ from llama_stack.cli.model.download import ModelDownload
 from llama_stack.cli.model.list import ModelList
 from llama_stack.cli.model.prompt_format import ModelPromptFormat
 from llama_stack.cli.model.verify_download import ModelVerifyDownload
-
 from llama_stack.cli.subcommand import Subcommand


@@ -8,7 +8,7 @@ import argparse
 import textwrap
 from io import StringIO

-from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
+from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family

 from llama_stack.cli.subcommand import Subcommand

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
 from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.datatypes import SamplingParams
 from llama_models.sku_list import LlamaDownloadInfo
-
 from pydantic import BaseModel, ConfigDict, Field


@@ -21,12 +21,11 @@ from prompt_toolkit.validation import Validator
 from termcolor import cprint

 from llama_stack.cli.table import print_table
-
 from llama_stack.distribution.build import (
+    SERVER_DEPENDENCIES,
+    ImageType,
     build_image,
     get_provider_dependencies,
-    ImageType,
-    SERVER_DEPENDENCIES,
 )
 from llama_stack.distribution.datatypes import (
     BuildConfig,
@@ -8,6 +8,7 @@ from datetime import datetime
+
 import pytest
 import yaml

 from llama_stack.distribution.configure import (
     LLAMA_STACK_RUN_CONFIG_VERSION,
     parse_and_maybe_upgrade_config,
@@ -8,7 +8,6 @@ import importlib.resources
 import logging
 import sys
 from enum import Enum
-
 from pathlib import Path
 from typing import Dict, List

@@ -16,11 +15,8 @@ from pydantic import BaseModel
 from termcolor import cprint

 from llama_stack.distribution.datatypes import BuildConfig, Provider
-
 from llama_stack.distribution.distribution import get_provider_registry
-
 from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
-
 from llama_stack.distribution.utils.exec import run_command, run_with_pty
 from llama_stack.providers.datatypes import Api

@@ -5,18 +5,16 @@
 # the root directory of this source tree.

 import inspect
-
 import json
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import Any, get_args, get_origin, Type, Union
+from typing import Any, Type, Union, get_args, get_origin

 import httpx
 from pydantic import BaseModel, parse_obj_as
 from termcolor import cprint

 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
-
 from llama_stack.providers.datatypes import RemoteProviderConfig

 _CLIENT_CLASSES = {}
@@ -5,12 +5,11 @@
 # the root directory of this source tree.
 import logging
 import textwrap
-
 from typing import Any, Dict

 from llama_stack.distribution.datatypes import (
-    DistributionSpec,
     LLAMA_STACK_RUN_CONFIG_VERSION,
+    DistributionSpec,
     Provider,
     StackRunConfig,
 )
@@ -20,7 +19,6 @@ from llama_stack.distribution.distribution import (
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
-
 from llama_stack.providers.datatypes import Api, ProviderSpec

 logger = logging.getLogger(__name__)
@@ -13,10 +13,21 @@ import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, get_args, get_origin, Optional, TypeVar
+from typing import Any, Optional, TypeVar, get_args, get_origin

 import httpx
 import yaml
+from llama_stack_client import (
+    NOT_GIVEN,
+    APIResponse,
+    AsyncAPIResponse,
+    AsyncLlamaStackClient,
+    AsyncStream,
+    LlamaStackClient,
+)
+from pydantic import BaseModel, TypeAdapter
+from rich.console import Console
+from termcolor import cprint

 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
@@ -35,17 +46,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
     setup_logger,
     start_trace,
 )
-from llama_stack_client import (
-    APIResponse,
-    AsyncAPIResponse,
-    AsyncLlamaStackClient,
-    AsyncStream,
-    LlamaStackClient,
-    NOT_GIVEN,
-)
-from pydantic import BaseModel, TypeAdapter
-from rich.console import Console
-from termcolor import cprint

 T = TypeVar("T")

@@ -7,7 +7,6 @@
 from typing import Any, Dict

 from llama_stack.distribution.datatypes import RoutedProtocol
-
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable

@@ -6,7 +6,7 @@

 from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     AppEvalTaskConfig,
@@ -10,11 +10,8 @@ from typing import Dict, List
 from pydantic import BaseModel

 from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
-
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
-
 from llama_stack.distribution.resolver import api_protocol_map
-
 from llama_stack.providers.datatypes import Api


@@ -7,9 +7,9 @@
 import argparse
 import asyncio
 import functools
-import logging
 import inspect
 import json
+import logging
 import os
 import signal
 import sys
@@ -21,7 +21,8 @@ from pathlib import Path
 from typing import Any, List, Union

 import yaml
-from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
+from fastapi import Body, FastAPI, HTTPException, Request
+from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
@@ -8,9 +8,9 @@ import os

 import pytest
 import pytest_asyncio
+
 from llama_stack.apis.inference import Model
 from llama_stack.apis.vector_dbs import VectorDB
-
 from llama_stack.distribution.store.registry import (
     CachedDiskDistributionRegistry,
     DiskDistributionRegistry,
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import os
-
 from typing import Optional

 from llama_stack_client import LlamaStackClient
@@ -10,7 +10,6 @@ from page.distribution.models import models
 from page.distribution.scoring_functions import scoring_functions
 from page.distribution.shields import shields
 from page.distribution.vector_dbs import vector_dbs
-
 from streamlit_option_menu import option_menu


@@ -8,7 +8,6 @@ import json

 import pandas as pd
 import streamlit as st
-
 from modules.api import llama_stack_api
 from modules.utils import process_dataset

@@ -7,9 +7,7 @@
 import json

 import pandas as pd
-
 import streamlit as st
-
 from modules.api import llama_stack_api


@@ -9,7 +9,6 @@ from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.memory_insert_params import Document
-
 from modules.api import llama_stack_api
 from modules.utils import data_url_from_file

@@ -7,7 +7,6 @@
 import os
 from pathlib import Path

-
 LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))

 DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
@@ -8,13 +8,11 @@ import inspect
 import json
 import logging
 from enum import Enum
-
-from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
+from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin

 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType
-
 from typing_extensions import Annotated

 log = logging.getLogger(__name__)
@@ -11,7 +11,6 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

 from llama_stack.apis.datasets import Dataset
-
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.models import Model
@@ -42,10 +42,10 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.common.content_types import (
+    URL,
     TextContentItem,
     ToolCallDelta,
     ToolCallParseStatus,
-    URL,
 )
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -6,11 +6,9 @@

 import asyncio
 import logging
-
 from typing import List

 from llama_stack.apis.inference import Message
-
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel

 log = logging.getLogger(__name__)
@@ -41,7 +41,6 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -15,14 +15,12 @@ import pandas
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.datasets import Dataset
-
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl
-
 from .config import LocalFSDatasetIOConfig


 DATASETS_PREFIX = "localfs_datasets:"

@@ -15,7 +15,6 @@ from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
-
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -28,7 +27,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl

 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
-
 from .config import MetaReferenceEvalConfig

 EVAL_TASKS_PREFIX = "eval_tasks:"
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
 from pydantic import BaseModel, field_validator

 from llama_stack.apis.inference import QuantizationConfig
-
 from llama_stack.providers.utils.inference import supported_inference_models


@@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel

@@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     ResponseFormatType,
 )
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
@@ -46,8 +46,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,
@@ -22,16 +22,13 @@ from typing import Callable, Generator, Literal, Optional, Union

 import torch
 import zmq
-
 from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_group,
     get_model_parallel_rank,
     get_model_parallel_src_rank,
 )
-
 from pydantic import BaseModel, Field
-
-from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated

 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -8,7 +8,6 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

 import collections
-
 import logging
 from typing import Optional, Type

@@ -23,7 +22,7 @@ except ImportError:
     raise

 import torch
-from torch import nn, Tensor
+from torch import Tensor, nn


 class Fp8ScaledWeights:
@@ -10,9 +10,9 @@
 import unittest

 import torch
-
-from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
-from hypothesis import given, settings, strategies as st
+from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
+from hypothesis import given, settings
+from hypothesis import strategies as st
 from torch import Tensor


@@ -12,18 +12,13 @@ import os
 from typing import Any, Dict, List, Optional

 import torch
-
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-
 from llama_models.datatypes import CheckpointQuantizationFormat
-
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
-
-from torch import nn, Tensor
-
+from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
@@ -16,14 +16,12 @@ from pathlib import Path
 from typing import Optional

 import fire
-
 import torch
 from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_rank,
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
@@ -15,9 +15,9 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    ToolConfig,
 )
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -37,9 +37,9 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
@@ -15,10 +15,8 @@ from typing import Any, Callable, Dict
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-
 from pydantic import BaseModel
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
-
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_1 import lora_llama3_1_8b
@@ -13,7 +13,6 @@
 from typing import Any, Dict, List, Mapping

 import numpy as np
-
 from torch.utils.data import Dataset
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
 from torchtune.data._messages import validate_messages
@@ -18,9 +18,9 @@ from llama_models.sku_list import resolve_model
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
+from torchtune import modules, training
+from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
-
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -44,14 +44,11 @@ from llama_stack.apis.post_training import (
     OptimizerConfig,
     TrainingConfig,
 )
-
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.inline.post_training.common.validator import (
     validate_input_dataset_schema,
 )
-
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
     TorchtuneCheckpointer,
@@ -21,7 +21,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import CodeScannerConfig
-

 log = logging.getLogger(__name__)

 ALLOWED_CODE_SCANNER_MODEL_IDS = [
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import re
-
 from string import Template
 from typing import Any, Dict, List, Optional

@@ -25,10 +24,8 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import LlamaGuardConfig
-

 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

 SAFE_RESPONSE = "safe"
@@ -8,7 +8,6 @@ import logging
 from typing import Any, Dict, List

 import torch
-
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from llama_stack.apis.inference import Message
@@ -19,7 +18,6 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -14,13 +14,13 @@ from llama_stack.apis.scoring import (
     ScoringResult,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
     get_valid_schemas,
     validate_dataset_schema,
 )
+
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@@ -7,7 +7,6 @@
 from typing import Any, Dict, Optional

 from llama_stack.apis.scoring import ScoringResultRow
-
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
-

 equality = ScoringFn(
     identifier="basic::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
-

 subset_of = ScoringFn(
     identifier="basic::subset_of",
     description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-
 from typing import Any, Dict, Optional

 from llama_stack.apis.scoring import ScoringResultRow
@@ -29,9 +29,7 @@ from llama_stack.apis.scoring import (
     ScoringResultRow,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
@@ -39,8 +37,8 @@ from llama_stack.providers.utils.common.data_schema_validator import (
     validate_dataset_schema,
     validate_row_schema,
 )
-
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
+
 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
 from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
-

 answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
     description=(
@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
-

 factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
     description=(
@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.inference.inference import Inference
-
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
     ScoreResponse,
@@ -26,7 +25,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
-

 LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]

@@ -7,7 +7,6 @@
 from llama_stack.apis.common.type_system import NumberType
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
-

 llm_as_judge_base = ScoringFn(
     identifier="llm-as-judge::base",
     description="Llm As Judge Scoring Function",
@@ -4,18 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-
 from typing import Any, Dict, Optional

 from llama_stack.apis.inference.inference import Inference
-
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
-
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

 from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
-
 from .fn_defs.llm_as_judge_base import llm_as_judge_base


@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from llama_stack.apis.telemetry import Telemetry
+
 from .config import SampleConfig


@@ -82,7 +82,11 @@ import sys as _sys
 # them with linters - they're used in code_execution.py
 from contextlib import (  # noqa
     contextmanager as _contextmanager,
+)
+from contextlib import (
     redirect_stderr as _redirect_stderr,
+)
+from contextlib import (
     redirect_stdout as _redirect_stdout,
 )
 from multiprocessing.connection import Connection as _Connection
@@ -9,7 +9,6 @@ from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage
-
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,
@@ -11,9 +11,9 @@ import string
 from typing import Any, Dict, List, Optional

 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     TextContentItem,
-    URL,
 )
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
@@ -7,6 +7,7 @@
 from typing import Dict

 from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import FaissImplConfig


@@ -8,11 +8,9 @@ import base64
 import io
 import json
 import logging
-
 from typing import Any, Dict, List, Optional

 import faiss
-
 import numpy as np
 from numpy.typing import NDArray

@@ -5,7 +5,9 @@
 # the root directory of this source tree.

 from typing import Dict
+
 from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import SQLiteVectorIOConfig


@@ -5,9 +5,10 @@
 # the root directory of this source tree.

 # config.py
-from pydantic import BaseModel
 from typing import Any, Dict

+from pydantic import BaseModel
+
 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
@@ -4,13 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import sqlite3
-import sqlite_vec
-import struct
 import logging
+import sqlite3
+import struct
+from typing import Any, Dict, List, Optional
+
 import numpy as np
+import sqlite_vec
 from numpy.typing import NDArray
-from typing import List, Optional, Dict, Any

 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from llama_stack.apis.agents import Agents
+
 from .config import SampleConfig


@@ -9,7 +9,6 @@ import datasets as hf_datasets

 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.datasets import Dataset
-
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -31,13 +31,13 @@ from llama_stack.apis.inference import (
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_strategy_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
@@ -29,8 +29,8 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -26,8 +26,8 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -31,8 +31,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
@@ -31,9 +31,9 @@ from llama_stack.apis.inference import (
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
+    ModelRegistryHelper,
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
-    ModelRegistryHelper,
 )

 from .groq_utils import (
@@ -24,10 +24,8 @@ from groq.types.chat.chat_completion_user_message_param import (
 )
 from groq.types.chat.completion_create_params import CompletionCreateParams
 from groq.types.shared.function_definition import FunctionDefinition
-
 from llama_models.llama3.api.datatypes import ToolParamDefinition
-

 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -47,9 +45,9 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-    convert_tool_call,
     UnparseableToolCall,
+    convert_tool_call,
+    get_sampling_strategy_options,
 )


@@ -29,8 +29,8 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

|
|||
from openai import AsyncStream
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
Choice as OpenAIChoice,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||
|
@ -69,7 +87,6 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
)
|
||||
|
|
|
@@ -8,7 +8,6 @@ from typing import Any, Dict

 from pydantic import BaseModel
-

 DEFAULT_OLLAMA_URL = "http://localhost:11434"


@@ -36,14 +36,14 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
+    ModelRegistryHelper,
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
-    ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -8,14 +8,12 @@ from typing import AsyncGenerator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403

 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -24,8 +24,8 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,
@@ -6,6 +6,7 @@

 from llama_stack.apis.inference import Inference
 from llama_stack.apis.models import Model
+
 from .config import SampleConfig


@@ -33,13 +33,13 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -30,8 +30,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
@@ -13,10 +13,14 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta
+from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
@@ -31,26 +35,22 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    CompletionMessage,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    ChatCompletionResponseEvent,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict,
-    get_sampling_options,
-    process_completion_response,
-    process_completion_stream_response,
     OpenAICompatCompletionResponse,
     UnparseableToolCall,
+    convert_message_to_openai_dict,
     convert_tool_call,
+    get_sampling_options,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
@@ -6,11 +6,9 @@

 import json
 import logging
-
 from typing import Any, Dict, List

 from llama_stack.apis.inference import Message
-
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -23,7 +21,6 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client

 from .config import BedrockSafetyConfig
-

 logger = logging.getLogger(__name__)


@@ -6,6 +6,7 @@

 from llama_stack.apis.safety import Safety
 from llama_stack.apis.shields import Shield
+
 from .config import SampleConfig


@@ -7,7 +7,6 @@
 from pydantic import BaseModel

 from .config import ModelContextProtocolConfig
-
 from .model_context_protocol import ModelContextProtocolToolRuntimeImpl


@@ -21,6 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
+
 from .config import ChromaRemoteImplConfig

 log = logging.getLogger(__name__)
@@ -10,15 +10,13 @@ from typing import Any, Dict, List, Optional, Tuple
 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
-from psycopg2.extras import execute_values, Json
-
+from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter

 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
-
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -20,6 +20,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
+
 from .config import QdrantConfig

 log = logging.getLogger(__name__)
@@ -6,6 +6,7 @@

 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import VectorIO
+
 from .config import SampleConfig


Some files were not shown because too many files have changed in this diff.