more fixes

This commit is contained in:
Xi Yan 2024-12-26 18:30:42 -08:00
parent 6596caed55
commit 27da763af9
10 changed files with 44 additions and 26 deletions

View file

@ -9,15 +9,26 @@ import logging
import shutil import shutil
import tempfile import tempfile
import uuid import uuid
from typing import AsyncGenerator from typing import AsyncGenerator, List, Optional, Union
from termcolor import colored from termcolor import colored
from llama_stack.apis.inference import Inference from llama_stack.apis.agents import (
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTurnCreateRequest,
Attachment,
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

View file

@ -10,9 +10,11 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -7,8 +7,6 @@
from typing import List from typing import List
from jinja2 import Template from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig, DefaultMemoryQueryGeneratorConfig,
@ -16,7 +14,7 @@ from llama_stack.apis.agents import (
MemoryQueryGenerator, MemoryQueryGenerator,
MemoryQueryGeneratorConfig, MemoryQueryGeneratorConfig,
) )
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import Message, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
) )

View file

@ -9,7 +9,9 @@ import logging
from typing import List from typing import List
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -7,7 +7,7 @@
from typing import List from typing import List
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import Safety
from ..safety import ShieldRunnerMixin from ..safety import ShieldRunnerMixin
from .builtin import BaseTool from .builtin import BaseTool

View file

@ -3,7 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403 from pydantic import BaseModel
class LocalFSDatasetIOConfig(BaseModel): ... class LocalFSDatasetIOConfig(BaseModel): ...

View file

@ -3,18 +3,19 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64 import base64
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

View file

@ -5,13 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from tqdm import tqdm from tqdm import tqdm
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
CompletionInputType,
StringType,
)
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.eval_tasks import EvalTask
@ -20,6 +22,9 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from .config import MetaReferenceEvalConfig from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:" EVAL_TASKS_PREFIX = "eval_tasks:"

View file

@ -6,11 +6,10 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models from llama_stack.providers.utils.inference import supported_inference_models

View file

@ -32,11 +32,11 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer, CrossAttentionTransformer,
) )
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel
from llama_stack.apis.inference import ResponseFormat, ResponseFormatType
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (